mirror of https://github.com/frankwxu/mobile-pii-discovery-agent.git (synced 2026-02-20 13:40:41 +00:00)
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a10c9a6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OK\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.messages import HumanMessage\n",
    "from sql_utils import *\n",
    "from datetime import datetime, timezone\n",
    "from pathlib import Path\n",
    "\n",
    "load_dotenv()  # This looks for the .env file and loads it into os.environ\n",
    "\n",
    "llm = ChatOpenAI(\n",
    "    model=\"gpt-4o-mini\",  # recommended for tools + cost\n",
    "    api_key=os.environ[\"API_KEY\"],\n",
    "    temperature=0,\n",
    "    seed=100,\n",
    ")\n",
    "\n",
    "response = llm.invoke([\n",
    "    HumanMessage(content=\"Reply with exactly: OK\")\n",
    "])\n",
    "\n",
    "print(response.content)"
   ]
  },
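  {
   "cell_type": "markdown",
   "id": "1a2b3c4d",
   "metadata": {},
   "source": [
    "The cell above assumes a `.env` file next to the notebook that defines `API_KEY`. A minimal sketch of the assumed layout (the value is a placeholder, not a real credential):\n",
    "\n",
    "```text\n",
    "# .env -- assumed layout; replace the placeholder with your own key\n",
    "API_KEY=sk-...\n",
    "```"
   ]
  },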
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "48eda3ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Core Python\n",
    "import sqlite3\n",
    "import re\n",
    "import json\n",
    "from typing import TypedDict, Optional, List, Annotated\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "# LangChain / LangGraph\n",
    "from langchain_core.tools import tool\n",
    "from langchain_core.messages import (\n",
    "    HumanMessage,\n",
    "    AIMessage,\n",
    "    SystemMessage\n",
    ")\n",
    "from langchain.agents import create_agent\n",
    "from langgraph.graph import StateGraph, END\n",
    "from langgraph.graph.message import MessagesState\n",
    "from sql_utils import *\n",
    "\n",
    "\n",
    "@tool\n",
    "def list_tables() -> str:\n",
    "    \"\"\"\n",
    "    List non-empty user tables in the SQLite database.\n",
    "    \"\"\"\n",
    "    IDENT_RE = re.compile(r\"^[A-Za-z_][A-Za-z0-9_]*$\")\n",
    "    conn = sqlite3.connect(DB_PATH)\n",
    "    try:\n",
    "        cur = conn.cursor()\n",
    "        cur.execute(\"\"\"\n",
    "            SELECT name\n",
    "            FROM sqlite_master\n",
    "            WHERE type='table' AND name NOT LIKE 'sqlite_%'\n",
    "            ORDER BY name\n",
    "        \"\"\")\n",
    "        tables = [r[0] for r in cur.fetchall()]\n",
    "\n",
    "        nonempty = []\n",
    "        for t in tables:\n",
    "            # If your DB has weird table names, remove this guard,\n",
    "            # but keep the quoting below.\n",
    "            if not IDENT_RE.match(t):\n",
    "                continue\n",
    "            try:\n",
    "                cur.execute(f'SELECT 1 FROM \"{t}\" LIMIT 1;')\n",
    "                if cur.fetchone() is not None:\n",
    "                    nonempty.append(t)\n",
    "            except sqlite3.Error:\n",
    "                continue\n",
    "\n",
    "        return \", \".join(nonempty)\n",
    "    finally:\n",
    "        conn.close()\n",
    "\n",
    "\n",
    "@tool\n",
    "def get_schema(table: str) -> str:\n",
    "    \"\"\"\n",
    "    Return column names and types for a table.\n",
    "    \"\"\"\n",
    "    conn = sqlite3.connect(DB_PATH)\n",
    "    cur = conn.cursor()\n",
    "    cur.execute(f\"PRAGMA table_info('{table}')\")\n",
    "    cols = cur.fetchall()\n",
    "    conn.close()\n",
    "    return \", \".join(f\"{c[1]} {c[2]}\" for c in cols)\n",
    "\n",
    "\n",
    "@tool\n",
    "def exec_sql(\n",
    "    query: str,\n",
    "    db_path: str,\n",
    "    top_n: int = 10,\n",
    "    verbose: bool = True,\n",
    ") -> dict:\n",
    "    \"\"\"\n",
    "    Execute a UNION ALL query by splitting it into individual SELECT statements.\n",
    "    Runs each SELECT with LIMIT top_n, skipping any SELECT that errors.\n",
    "\n",
    "    Returns:\n",
    "        rows: list of rows combined from all successful chunks\n",
    "        columns: list of column names (deduplicated), prefixed as table.column when possible\n",
    "    \"\"\"\n",
    "    query_text = normalize_sql(query)\n",
    "    selects = split_union_selects(query_text)\n",
    "\n",
    "    rows_all = []\n",
    "    column_names = []\n",
    "\n",
    "    conn = sqlite3.connect(db_path)\n",
    "    conn.create_function(\"REGEXP\", 2, regexp)\n",
    "    cur = conn.cursor()\n",
    "\n",
    "    try:\n",
    "        for i, select_sql in enumerate(selects, 1):\n",
    "            select_sql_clean = select_sql.rstrip().rstrip(\";\")\n",
    "            select_sql_run = f\"{select_sql_clean}\\nLIMIT {top_n};\"\n",
    "\n",
    "            if verbose:\n",
    "                print(f\"[EXECUTE] chunk {i}/{len(selects)} LIMIT {top_n}\")\n",
    "                # print(select_sql_run)  # uncomment to print the full SQL\n",
    "\n",
    "            try:\n",
    "                cur.execute(select_sql_run)\n",
    "                chunk = cur.fetchall()\n",
    "                rows_all.extend(chunk)\n",
    "\n",
    "                # Collect columns only if the chunk succeeded\n",
    "                tbl = extract_single_table(select_sql_clean)\n",
    "                for col in extract_select_columns(select_sql_clean):\n",
    "                    name = f\"{tbl}.{col}\" if (tbl and \".\" not in col) else col\n",
    "                    if name not in column_names:\n",
    "                        column_names.append(name)\n",
    "\n",
    "            except Exception as e:\n",
    "                if verbose:\n",
    "                    print(f\"[SQL ERROR] Skipping chunk {i}: {e}\")\n",
    "\n",
    "    finally:\n",
    "        conn.close()\n",
    "\n",
    "    return {\n",
    "        \"rows\": rows_all,\n",
    "        \"columns\": column_names\n",
    "    }\n",
    "\n",
    "\n",
    "from typing import Any\n",
    "\n",
    "class EvidenceState(TypedDict):\n",
    "    database_name: str\n",
    "    messages: Annotated[list, add_messages]\n",
    "    attempt: int\n",
    "    max_attempts: int\n",
    "    phase: str  # \"exploration\" | \"extraction\"\n",
    "\n",
    "    # SQL separation\n",
    "    exploration_sql: Optional[str]\n",
    "    extraction_sql: Optional[str]\n",
    "\n",
    "    rows: Optional[List]\n",
    "    classification: Optional[dict]\n",
    "    evidence: Optional[List[str]]\n",
    "\n",
    "    source_columns: Optional[List[str]]\n",
    "    entity_config: dict[str, Any]\n",
    "\n",
    "\n",
    "def get_explore_system(pii_type, regex):\n",
    "    return SystemMessage(\n",
    "        content=(\n",
    "            \"You are a SQL planner. You are given app databases extracted from Android or iPhone devices.\\n\"\n",
    "            \"Apps include Android WhatsApp, Snapchat, Telegram, Google Maps, Samsung Internet, iPhone Contacts, Messages, Safari, and Calendar.\\n\"\n",
    "            f\"Goal: discover whether any column of the databases contains possible {pii_type}.\\n\\n\"\n",
    "            \"Rules:\\n\"\n",
    "            \"- Use 'REGEXP' for pattern matching.\\n\"\n",
    "            f\"- Example: SELECT col FROM table WHERE col REGEXP '{regex}'\\n\"\n",
    "            \"- Table and column names can be used as hints to find solutions.\\n\"\n",
    "            \"- Include tables and columns even if there is only a small possibility that they contain solutions.\\n\"\n",
    "            \"- Pay attention to messages, chats, and other text fields.\\n\"\n",
    "            \"- Validate your SQL and make sure all tables and columns exist.\\n\"\n",
    "            \"- If multiple SQL statements are needed, combine them using UNION ALL.\\n\"\n",
    "            f\"- Example: SELECT col1 FROM table1 WHERE col1 REGEXP '{regex}' UNION ALL SELECT col2 FROM table2 WHERE col2 REGEXP '{regex}'\\n\"\n",
    "            \"- Make sure all tables and columns exist before returning SQL.\\n\"\n",
    "            \"- Return ONLY SQL.\"\n",
    "        )\n",
    "    )\n",
    "\n",
    "def planner(state: EvidenceState):\n",
    "    # Extraction upgrade path\n",
    "    if state[\"phase\"] == \"extraction\" and state.get(\"exploration_sql\"):\n",
    "        extraction_sql = upgrade_sql_remove_limit(state[\"exploration_sql\"])\n",
    "        return {\n",
    "            \"messages\": [AIMessage(content=extraction_sql)],\n",
    "            \"extraction_sql\": extraction_sql\n",
    "        }\n",
    "\n",
    "    # Optional safety stop inside the planner too\n",
    "    if state.get(\"phase\") == \"exploration\" and state.get(\"attempt\", 0) >= state.get(\"max_attempts\", 0):\n",
    "        return {\n",
    "            \"phase\": \"done\",\n",
    "            \"messages\": [AIMessage(content=\"STOP: max attempts reached in planner.\")]\n",
    "        }\n",
    "\n",
    "    # Original discovery logic\n",
    "    tables = list_tables.invoke({})\n",
    "    config = state[\"entity_config\"]\n",
    "\n",
    "    base_system = get_explore_system(\n",
    "        f\"{config.get('type', '')}:{config.get('desc', '')}\".strip(),\n",
    "        config[\"regex\"]\n",
    "    )\n",
    "\n",
    "    grounded_content = (\n",
    "        f\"{base_system.content}\\n\\n\"\n",
    "        f\"EXISTING TABLES: {tables}\\n\"\n",
    "        f\"CURRENT PHASE: {state['phase']}\\n\"\n",
    "        \"CRITICAL: Do not query non-existent tables.\"\n",
    "    )\n",
    "\n",
    "    agent = create_agent(llm, [list_tables, get_schema])\n",
    "\n",
    "    result = agent.invoke({\n",
    "        \"messages\": [\n",
    "            SystemMessage(content=grounded_content),\n",
    "            state[\"messages\"][0]  # original user request only\n",
    "        ]\n",
    "    })\n",
    "\n",
    "    exploration_sql = normalize_sql(result[\"messages\"][-1].content)\n",
    "\n",
    "    attempt = state[\"attempt\"] + 1 if state[\"phase\"] == \"exploration\" else state[\"attempt\"]\n",
    "\n",
    "    return {\n",
    "        \"messages\": [AIMessage(content=exploration_sql)],\n",
    "        \"exploration_sql\": exploration_sql,\n",
    "        \"attempt\": attempt\n",
    "    }\n",
    "\n",
    "def sql_execute(state: EvidenceState):\n",
    "    # Choose the SQL (and row cap) based on phase\n",
    "    if state[\"phase\"] == \"extraction\":\n",
    "        sql_to_run = state.get(\"extraction_sql\")\n",
    "        top_n = 10000\n",
    "    else:  # \"exploration\"\n",
    "        sql_to_run = state.get(\"exploration_sql\")\n",
    "        top_n = 10\n",
    "\n",
    "    if not sql_to_run:\n",
    "        print(\"[SQL EXEC] No SQL provided for this phase\")\n",
    "        return {\n",
    "            \"rows\": [],\n",
    "            \"messages\": [AIMessage(content=\"No SQL to execute\")]\n",
    "        }\n",
    "\n",
    "    # Execute\n",
    "    result = exec_sql.invoke({\n",
    "        \"query\": sql_to_run,\n",
    "        \"db_path\": state[\"database_name\"],\n",
    "        \"top_n\": top_n,\n",
    "        \"verbose\": False\n",
    "    })\n",
    "\n",
    "    rows = result.get(\"rows\", [])\n",
    "    cols = result.get(\"columns\", [])\n",
    "\n",
    "    print(f\"[SQL EXEC] Retrieved {len(rows)} rows\")\n",
    "\n",
    "    # for i, r in enumerate(rows, 1):\n",
    "    #     print(f\"  row[{i}]: {r}\")\n",
    "\n",
    "    updates = {\n",
    "        \"rows\": rows,\n",
    "        \"messages\": [AIMessage(content=f\"Retrieved {len(rows)} rows\")]\n",
    "    }\n",
    "\n",
    "    # Track columns only during extraction (provenance)\n",
    "    if state[\"phase\"] == \"extraction\":\n",
    "        updates[\"source_columns\"] = cols\n",
    "        print(f\"[TRACKING] Saved source columns: {cols}\")\n",
    "\n",
    "    return updates\n",
    "\n",
    "\n",
    "def classify(state: EvidenceState):\n",
    "    # 1. Prepare the text sample for the LLM\n",
    "    text = rows_to_text(state[\"rows\"], limit=15)\n",
    "\n",
    "    # 2. Get the PII-specific system message\n",
    "    config = state[\"entity_config\"]\n",
    "    pii_desc = f\"{config.get('type', '')}:{config.get('desc', '')}\".strip()\n",
    "    system_message = SystemMessage(\n",
    "        content=(\n",
    "            f\"Decide whether the text contains {pii_desc}.\\n\"\n",
    "            \"Return ONLY a JSON object with these keys:\\n\"\n",
    "            \"{ \\\"found\\\": true/false, \\\"confidence\\\": number, \\\"reason\\\": \\\"string\\\" }\"\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # 3. Invoke the LLM\n",
    "    result = llm.invoke([\n",
    "        system_message,\n",
    "        HumanMessage(content=f\"Data to analyze:\\n{text}\")\n",
    "    ]).content\n",
    "\n",
    "    # 4. Parse the decision\n",
    "    decision = safe_json_loads(\n",
    "        result,\n",
    "        default={\"found\": False, \"confidence\": 0.0, \"reason\": \"parse failure\"}\n",
    "    )\n",
    "\n",
    "    # print(\"[CLASSIFY]\", decision)\n",
    "    return {\"classification\": decision}\n",
    "\n",
    "\n",
    "def switch_to_extraction(state: EvidenceState):\n",
    "    print(\"[PHASE] discovery → extraction\")\n",
    "    return {\"phase\": \"extraction\"}\n",
    "\n",
    "\n",
    "def extract(state: EvidenceState):\n",
    "    text = rows_to_text(state[\"rows\"])\n",
    "    # print(f\"Check the last 100 characters: {text[-100:]}\")\n",
    "    desc = state[\"entity_config\"].get(\"desc\", \"PII\")\n",
    "    system = f\"Identify real {desc} from the text and normalize them. Return ONLY a JSON array of strings.\\n\"\n",
    "\n",
    "    result = llm.invoke([SystemMessage(content=system), HumanMessage(content=text)]).content\n",
    "    return {\"evidence\": safe_json_loads(result, default=[])}\n",
    "\n",
    "\n",
    "def next_step(state: EvidenceState):\n",
    "    # Once in the extraction phase, extract and stop\n",
    "    if state[\"phase\"] == \"extraction\":\n",
    "        return \"do_extract\"\n",
    "\n",
    "    c = state[\"classification\"]\n",
    "\n",
    "    if c[\"found\"] and c[\"confidence\"] >= 0.6:\n",
    "        return \"to_extraction\"\n",
    "\n",
    "    if not c[\"found\"] and c[\"confidence\"] >= 0.6:\n",
    "        return \"stop_none\"\n",
    "\n",
    "    if state[\"attempt\"] >= state[\"max_attempts\"]:\n",
    "        return \"stop_limit\"\n",
    "\n",
    "    return \"replan\""
   ]
  },
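  {
   "cell_type": "markdown",
   "id": "b4c5d6e7",
   "metadata": {},
   "source": [
    "SQLite parses the `REGEXP` operator but ships no implementation of it, which is why `exec_sql` registers one via `conn.create_function(\"REGEXP\", 2, regexp)`. The `regexp` helper itself comes from `sql_utils`; a minimal, self-contained sketch of the behavior it needs to provide (an assumption, not the actual `sql_utils` code):\n",
    "\n",
    "```python\n",
    "import re\n",
    "import sqlite3\n",
    "\n",
    "def regexp(pattern, value):\n",
    "    # SQLite rewrites \"x REGEXP y\" as regexp(y, x): pattern first, value second.\n",
    "    if value is None:\n",
    "        return False\n",
    "    return re.search(pattern, str(value)) is not None\n",
    "\n",
    "conn = sqlite3.connect(\":memory:\")\n",
    "conn.create_function(\"REGEXP\", 2, regexp)\n",
    "conn.execute(\"CREATE TABLE t(msg TEXT)\")\n",
    "conn.execute(\"INSERT INTO t VALUES ('reach me at alice@example.com')\")\n",
    "print(conn.execute(\n",
    "    \"SELECT msg FROM t WHERE msg REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+'\"\n",
    ").fetchall())  # -> [('reach me at alice@example.com',)]\n",
    "```"
   ]
  },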
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0f5259d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict\n",
    "from functools import partial\n",
    "from langgraph.graph import StateGraph, END\n",
    "\n",
    "def observe(\n",
    "    state: \"EvidenceState\",\n",
    "    enable_observe: bool = False,\n",
    "    label: str = \"OBSERVE\",\n",
    "    sample_rows: int = 20,\n",
    "    sample_evidence: int = 10,\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Debug / inspection node.\n",
    "    Does NOT modify state.\n",
    "    If enable_observe is False, prints nothing.\n",
    "    \"\"\"\n",
    "    if not enable_observe:\n",
    "        return {}\n",
    "\n",
    "    print(f\"\\n=== STATE SNAPSHOT [{label}] ===\")\n",
    "\n",
    "    # Messages\n",
    "    print(\"\\n--- MESSAGES ---\")\n",
    "    for i, m in enumerate(state.get(\"messages\", [])):\n",
    "        mtype = getattr(m, \"type\", \"unknown\")\n",
    "        mcontent = getattr(m, \"content\", str(m))\n",
    "        print(f\"{i}: {str(mtype).upper()} -> {mcontent}\")\n",
    "\n",
    "    # Metadata\n",
    "    print(\"\\n--- BEGIN METADATA ---\")\n",
    "    print(f\"attempt         : {state.get('attempt')}\")\n",
    "    print(f\"max_attempts    : {state.get('max_attempts')}\")\n",
    "    print(f\"phase           : {state.get('phase')}\")\n",
    "    print(f\"PII type        : {(state.get('entity_config') or {}).get('type')}\")\n",
    "\n",
    "    # SQL separation\n",
    "    print(f\"exploration_sql : {state.get('exploration_sql')}\")\n",
    "    print(f\"extraction_sql  : {state.get('extraction_sql')}\")\n",
    "\n",
    "    # Outputs\n",
    "    rows = state.get(\"rows\") or []\n",
    "    print(f\"rows_count      : {len(rows)}\")\n",
    "    print(f\"rows_sample     : {rows[:sample_rows] if rows else []}\")\n",
    "\n",
    "    evidence = state.get(\"evidence\") or []\n",
    "    print(f\"classification  : {state.get('classification')}\")\n",
    "    print(f\"evidence_count  : {len(evidence)}\")\n",
    "    print(f\"evidence_sample : {evidence[:sample_evidence]}\")\n",
    "\n",
    "    print(f\"source_columns  : {state.get('source_columns')}\")\n",
    "    print(\"\\n--- END METADATA ---\")\n",
    "\n",
    "    return {}  # no-op update\n",
    "\n",
    "\n",
    "# ---- Build graph with an enable flag ----\n",
    "def build_graph(enable_observe: bool = False):\n",
    "    graph = StateGraph(EvidenceState)\n",
    "\n",
    "    # Nodes\n",
    "    graph.add_node(\"planner\", planner)\n",
    "\n",
    "    # Wrap observe so it matches (state) -> update\n",
    "    graph.add_node(\"observe_plan\", partial(observe, enable_observe=enable_observe, label=\"PLAN\"))\n",
    "    graph.add_node(\"execute\", sql_execute)\n",
    "    graph.add_node(\"observe_execution\", partial(observe, enable_observe=enable_observe, label=\"EXECUTION\"))\n",
    "    graph.add_node(\"classify\", classify)\n",
    "    graph.add_node(\"observe_classify\", partial(observe, enable_observe=enable_observe, label=\"CLASSIFY\"))\n",
    "    graph.add_node(\"switch_phase\", switch_to_extraction)\n",
    "    graph.add_node(\"extract\", extract)\n",
    "    graph.add_node(\"observe_final\", partial(observe, enable_observe=enable_observe, label=\"FINAL\"))\n",
    "\n",
    "    graph.set_entry_point(\"planner\")\n",
    "\n",
    "    # --- FLOW ---\n",
    "    graph.add_edge(\"planner\", \"observe_plan\")\n",
    "    graph.add_edge(\"observe_plan\", \"execute\")\n",
    "\n",
    "    graph.add_edge(\"execute\", \"observe_execution\")\n",
    "    graph.add_edge(\"observe_execution\", \"classify\")\n",
    "\n",
    "    graph.add_edge(\"classify\", \"observe_classify\")\n",
    "\n",
    "    graph.add_conditional_edges(\n",
    "        \"observe_classify\",\n",
    "        next_step,\n",
    "        {\n",
    "            \"to_extraction\": \"switch_phase\",\n",
    "            \"do_extract\": \"extract\",\n",
    "            \"replan\": \"planner\",\n",
    "            \"stop_none\": END,\n",
    "            \"stop_limit\": END,\n",
    "        },\n",
    "    )\n",
    "\n",
    "    graph.add_edge(\"switch_phase\", \"planner\")\n",
    "    graph.add_edge(\"extract\", \"observe_final\")\n",
    "    graph.add_edge(\"observe_final\", END)\n",
    "\n",
    "    return graph.compile()"
   ]
  },
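  {
   "cell_type": "markdown",
   "id": "c8d9e0f1",
   "metadata": {},
   "source": [
    "The driver cell below reads `config.yaml` plus a Python run-config module, using helpers (`load_config_yaml`, `load_vars_from_py`) apparently provided by the `sql_utils` star import. A minimal sketch of the two files it expects, inferred from the keys used in `main()` (all values are illustrative):\n",
    "\n",
    "```yaml\n",
    "# config.yaml -- assumed layout\n",
    "db_dir: selectedDBs\n",
    "out_dir: batch_results\n",
    "config_py: my_run_config.py\n",
    "enable_observe: false\n",
    "pii_targets: [EMAIL, PHONE, USERNAME, PERSON_NAME, POSTAL_ADDRESS]\n",
    "```\n",
    "\n",
    "```python\n",
    "# my_run_config.py -- assumed layout; the regex is illustrative\n",
    "db_files = [\"test2.db\", \"users.db\"]\n",
    "\n",
    "PII_CONFIG = {\n",
    "    \"EMAIL\": {\n",
    "        \"type\": \"EMAIL\",\n",
    "        \"desc\": \"email addresses\",\n",
    "        \"regex\": \"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\",\n",
    "    },\n",
    "    # ... one entry per PII target, each with type / desc / regex\n",
    "}\n",
    "```"
   ]
  },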
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "655e0915",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Will process 2 databases (from db_files list).\n",
      "enable_observe: False\n",
      "pii_targets: ['EMAIL', 'PHONE', 'USERNAME', 'PERSON_NAME', 'POSTAL_ADDRESS']\n",
      "\n",
      "Processing DB: selectedDBs\\test2.db\n",
      " Processing: EMAIL\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[TRACKING] Saved source columns: ['users.email']\n",
      " Processing: PHONE\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[TRACKING] Saved source columns: ['users.phone']\n",
      " Processing: USERNAME\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[TRACKING] Saved source columns: ['users.username']\n",
      " Processing: PERSON_NAME\n",
      "[SQL EXEC] Retrieved 30 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 30 rows\n",
      "[TRACKING] Saved source columns: ['users.first_name', 'users.last_name', 'users.username']\n",
      " Processing: POSTAL_ADDRESS\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 10 rows\n",
      "[TRACKING] Saved source columns: ['users.street', 'users.city', 'users.state', 'users.zip_code', 'users.phone']\n",
      "Wrote: I:\\project2026\\llmagent\\batch_results\\PII_test2_20260202T021704Z.jsonl\n",
      "\n",
      "Processing DB: selectedDBs\\users.db\n",
      " Processing: EMAIL\n",
      "[SQL EXEC] Retrieved 3 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 3 rows\n",
      "[TRACKING] Saved source columns: ['users.message']\n",
      " Processing: PHONE\n",
      "[SQL EXEC] Retrieved 0 rows\n",
      " Processing: USERNAME\n",
      "[SQL EXEC] Retrieved 6 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 6 rows\n",
      "[TRACKING] Saved source columns: ['users.first_name', 'users.message']\n",
      " Processing: PERSON_NAME\n",
      "[SQL EXEC] Retrieved 6 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 6 rows\n",
      "[TRACKING] Saved source columns: ['users.first_name', 'users.message']\n",
      " Processing: POSTAL_ADDRESS\n",
      "[SQL EXEC] Retrieved 1 rows\n",
      "[PHASE] discovery → extraction\n",
      "[SQL EXEC] Retrieved 1 rows\n",
      "[TRACKING] Saved source columns: ['users.message']\n",
      "Wrote: I:\\project2026\\llmagent\\batch_results\\PII_users_20260202T021737Z.jsonl\n"
     ]
    }
   ],
   "source": [
    "def run_batch(db_paths, pii_targets, pii_config, app, out_dir: Path):\n",
    "    \"\"\"\n",
    "    Process databases one by one and write one output file per database.\n",
    "    \"\"\"\n",
    "    for p in db_paths:\n",
    "        db_path = str(p)\n",
    "\n",
    "        # If your tools rely on the global DB_PATH, keep this line.\n",
    "        # If you refactor the tools to use state[\"database_name\"], you can remove it.\n",
    "        global DB_PATH\n",
    "        DB_PATH = db_path\n",
    "\n",
    "        print(f\"\\nProcessing DB: {db_path}\")\n",
    "\n",
    "        db_results = []  # reset per database\n",
    "\n",
    "        for target in pii_targets:\n",
    "            entity_config = pii_config[target]\n",
    "            print(f\" Processing: {target}\")\n",
    "\n",
    "            result = app.invoke({\n",
    "                \"database_name\": db_path,\n",
    "                \"messages\": [HumanMessage(content=f\"Find {entity_config['type'].strip()}, {entity_config['desc'].strip()}, in the database.\")],\n",
    "                \"attempt\": 1,\n",
    "                \"max_attempts\": 2,\n",
    "                \"phase\": \"exploration\",\n",
    "                \"entity_config\": entity_config,\n",
    "                \"exploration_sql\": None,\n",
    "                \"extraction_sql\": None,\n",
    "                \"rows\": None,\n",
    "                \"classification\": None,\n",
    "                \"evidence\": [],\n",
    "                \"source_columns\": []\n",
    "            })\n",
    "\n",
    "            evidence = result.get(\"evidence\", [])\n",
    "            source_columns = result.get(\"source_columns\", [])\n",
    "            raw_rows = result.get(\"rows\", [])\n",
    "\n",
    "            db_results.append({\n",
    "                \"db_path\": db_path,\n",
    "                \"PII_type\": target,\n",
    "                \"PII\": evidence,\n",
    "                \"Num_of_PII\": len(evidence),\n",
    "                \"source_columns\": source_columns,\n",
    "                \"Raw_rows_first_100\": raw_rows[:100],\n",
    "                \"Total_raw_rows\": len(raw_rows),\n",
    "                \"Exploration_sql\": result.get(\"exploration_sql\", \"\"),\n",
    "                \"Extraction_sql\": result.get(\"extraction_sql\", \"\"),\n",
    "                \"PII_Prompt\": entity_config.get(\"desc\", \"\")\n",
    "            })\n",
    "\n",
    "        # Save per-database output (file name includes the db name + a timestamp)\n",
    "        save_jsonl(db_results, out_dir, db_path)\n",
    "\n",
    "\n",
    "def main():\n",
    "    cfg = load_config_yaml(Path(\"config.yaml\"))\n",
    "\n",
    "    DB_DIR = Path(cfg.get(\"db_dir\", \"selectedDBs\"))\n",
    "    OUT_DIR = Path(cfg.get(\"out_dir\", \"batch_results\"))\n",
    "    OUT_DIR.mkdir(exist_ok=True)\n",
    "\n",
    "    CONFIG_PY = Path(cfg.get(\"config_py\", \"my_run_config.py\"))\n",
    "    vars_ = load_vars_from_py(CONFIG_PY, \"db_files\", \"PII_CONFIG\")\n",
    "    db_files = vars_[\"db_files\"]\n",
    "    PII_CONFIG = vars_[\"PII_CONFIG\"]\n",
    "\n",
    "    PII_TARGETS = cfg.get(\"pii_targets\", list(PII_CONFIG.keys()))\n",
    "\n",
    "    db_paths, missing, not_sqlite = build_db_paths(DB_DIR, db_files, is_sqlite_file)\n",
    "    print_db_path_report(db_paths, missing, not_sqlite)\n",
    "\n",
    "    # Run and save one file per DB (no global aggregation)\n",
    "    enable_observe = bool(cfg.get(\"enable_observe\", False))\n",
    "    app = build_graph(enable_observe)\n",
    "\n",
    "    print(f\"enable_observe: {enable_observe}\")\n",
    "    print(f\"pii_targets: {PII_TARGETS}\")\n",
    "\n",
    "    run_batch(db_paths, PII_TARGETS, PII_CONFIG, app, OUT_DIR)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
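,
  {
   "cell_type": "markdown",
   "id": "d2e3f4a5",
   "metadata": {},
   "source": [
    "`save_jsonl` (from `sql_utils`) writes one JSONL file per database, named like `PII_users_20260202T021737Z.jsonl`, with one record per PII type. Based on the fields assembled in `run_batch`, a record presumably looks roughly like this (values illustrative):\n",
    "\n",
    "```json\n",
    "{\"db_path\": \"selectedDBs/users.db\", \"PII_type\": \"EMAIL\",\n",
    " \"PII\": [\"alice@example.com\"], \"Num_of_PII\": 1,\n",
    " \"source_columns\": [\"users.message\"],\n",
    " \"Raw_rows_first_100\": [[\"...row text...\"]], \"Total_raw_rows\": 3,\n",
    " \"Exploration_sql\": \"SELECT message FROM users WHERE message REGEXP '...'\",\n",
    " \"Extraction_sql\": \"SELECT message FROM users WHERE message REGEXP '...'\",\n",
    " \"PII_Prompt\": \"email addresses\"}\n",
    "```"
   ]
  }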
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}