diff --git a/evaluation/locomo/episodic_memory/generate_scores.py b/evaluation/locomo/episodic_memory/generate_scores.py index 36b559a04..fbfbf0d82 100644 --- a/evaluation/locomo/episodic_memory/generate_scores.py +++ b/evaluation/locomo/episodic_memory/generate_scores.py @@ -1,12 +1,14 @@ # This is adapted from Mem0 (https://github.com/mem0ai/mem0/blob/main/evaluation/generate_scores.py). # It has been modified to only report LLM judge scores. -import json +import json, sys import pandas as pd +path = sys.argv[1] + # Load the evaluation metrics data -with open("evaluation_metrics.json", "r") as f: +with open(path, "r") as f: data = json.load(f) # Flatten the data into a list of question items @@ -14,11 +16,49 @@ for key in data: all_items.extend(data[key]) +final_matrix = "" +num_use_kg = 0 +num_use_em = 0 +num_positive_use_em = 0 +num_positive_use_kg = 0 +num_correct = 0 +num_incorrect = 0 +num_total = 0 +for item in all_items: + if "final_matrix" in item: + final_matrix = item["final_matrix"] + + if item.get("used_kg", True): + num_use_kg += 1 + if item["llm_score"] == 1: + num_positive_use_kg += 1 + elif item.get("used_em", False): + num_use_em += 1 + if item["llm_score"] == 1: + num_positive_use_em += 1 + + if item["llm_score"] == 1: + num_correct += 1 + else: + num_incorrect += 1 + + num_total += 1 + +if num_use_kg != 0: + final_matrix += f"Positive cases using KG(EM search not sufficient): {num_positive_use_kg}/{num_use_kg} = {num_positive_use_kg/num_use_kg*100:.2f}%\n" +else: + final_matrix += "Using 0 KG searches.\n" + +if num_use_em != 0: + final_matrix += f"Positive cases using EM only: {num_positive_use_em}/{num_use_em} = {num_positive_use_em/num_use_em*100:.2f}%\n" +else: + final_matrix += "Using 0 EM searches.\n" + # Convert to DataFrame df = pd.DataFrame(all_items) # Convert category to numeric type -df["category"] = pd.to_numeric(df["category"]) +# df["category"] = pd.to_numeric(df["category"]) # Calculate mean scores by category result = df.groupby("category").agg({"llm_score": "mean"}).round(4) @@ -35,3 +75,8 @@ print("\nOverall Mean Scores:") print(overall_means) + +# print(f"\nNumber of positive cases using long-term memory: {num_positive_use_em}") +# print(f"Number of negative cases using long-term memory: {num_negative_use_em}") +# print(f"Total correct answers: {num_correct}/{num_total}") +print(f"\nFinal Info Matrix:\n{final_matrix}") \ No newline at end of file diff --git a/evaluation/locomo/episodic_memory/llm_judge.py b/evaluation/locomo/episodic_memory/llm_judge.py index e4e18b255..77b4d6787 100644 --- a/evaluation/locomo/episodic_memory/llm_judge.py +++ b/evaluation/locomo/episodic_memory/llm_judge.py @@ -42,7 +42,7 @@ def evaluate_llm_judge(question, gold_answer, generated_answer): """Evaluate the generated answer against the gold answer using an LLM judge.""" response = client.chat.completions.create( - model="gpt-4o-mini", + model="gpt-4.1-mini", messages=[ { "role": "user", diff --git a/evaluation/locomo/episodic_memory/locomo_config.yaml b/evaluation/locomo/episodic_memory/locomo_config.yaml index fe60c9eab..993b37c13 100644 --- a/evaluation/locomo/episodic_memory/locomo_config.yaml +++ b/evaluation/locomo/episodic_memory/locomo_config.yaml @@ -18,7 +18,7 @@ storage: host: localhost port: 7687 user: neo4j - password: + password: password force_exact_similarity_search: true embedder: @@ -34,11 +34,11 @@ reranker: reranker_ids: - id_ranker_id - bm_ranker_id - - ce_ranker_id + # - ce_ranker_id id_ranker_id: type: "identity" bm_ranker_id: type: "bm25" - ce_ranker_id: - type: "cross-encoder" - model_name: "cross-encoder/qnli-electra-base" + # ce_ranker_id: + # type: "cross-encoder" + # model_name: "cross-encoder/qnli-electra-base" diff --git a/evaluation/locomo/episodic_memory/locomo_evaluate.py b/evaluation/locomo/episodic_memory/locomo_evaluate.py index 10853df1d..a5847fea8 100644 --- a/evaluation/locomo/episodic_memory/locomo_evaluate.py +++ b/evaluation/locomo/episodic_memory/locomo_evaluate.py @@ -1,53 +1,112 @@ # This is adapted from Mem0 (https://github.com/mem0ai/mem0/blob/main/evaluation/evals.py). -# It is modified to only report LLM judge scores and to be simpler. +# It is modified to only report LLM judge scores. import argparse +import concurrent.futures import json +import threading +from collections import defaultdict from dotenv import load_dotenv from llm_judge import evaluate_llm_judge +from tqdm import tqdm + +load_dotenv() + + +def process_item(item_data): + k, v = item_data + local_results = defaultdict(list) + + for item in tqdm(v, desc=f"Processing {k} sample"): + question = str(item["question"]) + locomo_answer = str(item["locomo_answer"]) + response = str(item["model_answer"]) + category = str(item["category"]) + + # Skip category 5 + if category == "5": + continue + + llm_score = evaluate_llm_judge(question, locomo_answer, response) + + res = { + "question": question, + "answer": locomo_answer, + "response": response, + "category": category, + "llm_score": llm_score, + } + for key, val in item.items(): + if key not in [ + "question", + "locomo_answer", + "model_answer", + "category", + ]: + if type(val) is float: + # Round to 3 decimal places + val = round(val, 3) + res[key] = val + if res["llm_score"] == 0 and res.get("used_em", False): + res["em_negative"] = 1 + + local_results[k].append(res) + + return local_results def main(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="Evaluate results") parser.add_argument( - "--data-path", required=True, help="Path to the source data file" + "--input_file", + type=str, + default="results/rag_results_500_k1.json", + help="Path to the input dataset file", ) parser.add_argument( - "--target-path", required=True, help="Path to the target data file" + "--output_file", + type=str, + default="evaluation_metrics.json", + help="Path to save the evaluation results", ) + parser.add_argument( + "--max_workers", + type=int, + default=10, + help="Maximum number of worker threads", + ) + args = parser.parse_args() - data_path = args.data_path - target_path = args.target_path - # Load environment variables - load_dotenv() + with open(args.input_file, "r") as f: + data = json.load(f) - with open(data_path, "r") as f: - test_data = json.load(f) - results = {} - for key, value in test_data.items(): - if key == "5": - continue - local_result = [] - for item in value: - question = item["question"] - locomo_answer = f"{item['locomo_answer']}" - response = f"{item['model_answer']}" - llm_score = evaluate_llm_judge(question, locomo_answer, response) - local_result.append( - { - "question": question, - "answer": locomo_answer, - "response": response, - "category": key, - "llm_score": llm_score, - } - ) - results[key] = local_result - with open(target_path, "w") as f: - json.dump(results, f, indent=4) + results = defaultdict(list) + results_lock = threading.Lock() + + # Use ThreadPoolExecutor with specified workers + with concurrent.futures.ThreadPoolExecutor( + max_workers=args.max_workers + ) as executor: + futures = [ + executor.submit(process_item, item_data) for item_data in data.items() + ] + + for future in tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + local_results = future.result() + with results_lock: + for k, items in local_results.items(): + results[k].extend(items) + + # Save results to JSON file + with open(args.output_file, "w") as f: + json.dump(results, f, indent=4) + + print(f"Results saved to {args.output_file}") if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/evaluation/locomo/episodic_memory/locomo_ingest.py b/evaluation/locomo/episodic_memory/locomo_ingest.py index 42b2dc8af..38d0857f7 100644 --- a/evaluation/locomo/episodic_memory/locomo_ingest.py +++ b/evaluation/locomo/episodic_memory/locomo_ingest.py @@ -1,9 +1,12 @@ import argparse import asyncio import json +import uuid from collections import deque -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import cast +import time +import os from dotenv import load_dotenv @@ -12,7 +15,128 @@ from memmachine.episodic_memory.episodic_memory_manager import ( EpisodicMemoryManager, ) +from memmachine.knowledge_graph.re_gpt4_1 import KnowledgeGraph +from memmachine.common.vector_graph_store import Node +def load_data( + start_line: int = 1, + num_cases: int = 100, + randomize: bool = True, +): + # dataset = "/home/tomz/MemMachine/evaluation/locomo/wiki-filter-gpt-4o-mini.json" + dataset = "/home/tomz/MemMachine/evaluation/locomo/wikimultihop.json" + print(f"Loading data from line {start_line} to {num_cases}, randomize={randomize}") + contexts = [] + questions = [] + answers = [] + types = [] + i = 1 + with open(dataset, "r", encoding="utf-8") as f: + for line in f: + if i < start_line: + i += 1 + continue + if i > num_cases: + break + + line = line.strip() + if not line: + continue + obj = json.loads(line) + obj["context"] = json.loads(obj["context"]) + c_list = [] + for key, sentences in obj["context"]: + for s in sentences: + c = f"{key}: {s}" + if randomize: + insert_index = random.randrange(len(contexts) + 1) # 0..len inclusive + c_list.insert(insert_index, c) + else: + c_list.append(c) + + contexts.append(c_list) + questions.append(obj["question"]) + answers.append(obj["answer"]) + types.append(obj["type"]) + i += 1 + return contexts, questions, answers, types + +async def ingest_wikimultihop(): + memory_manager = EpisodicMemoryManager.create_episodic_memory_manager( + "locomo_config.yaml" + ) + + memory = cast( + EpisodicMemory, + await memory_manager.get_episodic_memory_instance( + group_id="1", + session_id="1", + user_id=["user"], + ), + ) + + contexts, _, _, _ = load_data(start_line=1, num_cases=305, randomize=False) + + print("Loaded", len(contexts), "contexts, start ingestion...") + + num_batch = 50 + em_tasks = [] + episodes = [] + added_contexts = set() + num_added = 0 + t1 = datetime.now(timezone.utc) + for c_list in contexts: + for c in c_list: + if c not in added_contexts: + added_contexts.add(c) + num_added += 1 + # if num_added <= 7885: + # continue + + cur_uuid = uuid.uuid4() + ts = t1 + timedelta(seconds=len(added_contexts)) + episodes.append(Node( + uuid=cur_uuid, + labels={"Episode"}, + # Make timestamp different for each episode + properties={ + "content": c, + "timestamp": ts, + "session_id": 1 + }, + )) + + producer = c.split(":")[0] + ts_str = ts.strftime("%Y-%m-%d %H:%M:%S") + em_tasks.append(memory.add_memory_episode( + producer="user", + produced_for="user", + episode_content=c, + episode_type="default", + content_type=ContentType.STRING, + timestamp=ts_str, + metadata={ + "source_timestamp": ts_str, + "source_speaker": "user", + }, + uuid=cur_uuid, + ) + ) + + if len(added_contexts) % num_batch == 0 or (c_list == contexts[-1] and c == c_list[-1]): + t = time.perf_counter() + await add_episode_bulk(episodes) + print(f"Gathered and added {len(episodes)} episodes to KG in {(time.perf_counter() - t):.3f}s") + episodes = [] + + t = time.perf_counter() + await asyncio.gather(*em_tasks) + print(f"Added {len(em_tasks)} episodes to EM in {(time.perf_counter() - t):.3f}s") + em_tasks = [] + + print(f"Total added episodes: {len(added_contexts)}") + + print(f"Completed WIKI-Multihop ingestion, added {len(added_contexts)} episodes.") async def main(): parser = argparse.ArgumentParser() @@ -23,16 +147,62 @@ async def main(): data_path = args.data_path - with open(data_path, "r") as f: - locomo_data = json.load(f) - memory_manager = EpisodicMemoryManager.create_episodic_memory_manager( "locomo_config.yaml" ) + from openai import AsyncOpenAI + model = AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + timeout=600, + ) + + from neo4j import AsyncGraphDatabase + driver = AsyncGraphDatabase.driver( + "bolt://localhost:9999", + auth=( + "neo4j", + "password", + ), + connection_timeout=3600, # seconds + # how long to wait for a pooled connection + connection_acquisition_timeout=3600, # seconds + # built-in retry window for transient failures + max_transaction_retry_time=3600, + max_connection_lifetime=3600, + keep_alive=True, + ) + + from memmachine.common.vector_graph_store.neo4j_vector_graph_store import Neo4jVectorGraphStore, Neo4jVectorGraphStoreParams + store = Neo4jVectorGraphStore( + Neo4jVectorGraphStoreParams( + driver=driver, + max_concurrent_transactions=200, + force_exact_similarity_search=False, + ) + ) + + await store.create_fulltext_index() + + from memmachine.common.embedder.openai_embedder import OpenAIEmbedder + embedder = OpenAIEmbedder( + { + "model": "text-embedding-3-small", + "api_key": os.getenv("OPENAI_API_KEY"), + } + ) + + # await ingest_wikimultihop() + # return + + with open(data_path, "r") as f: + locomo_data = json.load(f) + async def process_conversation(idx, item, memory_manager: EpisodicMemoryManager): if "conversation" not in item: return + + nonlocal model, store, embedder conversation = item["conversation"] speaker_a = conversation["speaker_a"] @@ -44,6 +214,15 @@ async def process_conversation(idx, item, memory_manager: EpisodicMemoryManager) group_id = f"group_{idx}" + model_name="gpt-4.1-mini" + print(f"Creating Knowledge Graph with model {model_name}...") + kg = KnowledgeGraph( + model_name=model_name, + model=model, + embedder=embedder, + store=store, + ) + memory = cast( EpisodicMemory, await memory_manager.get_episodic_memory_instance( @@ -53,7 +232,10 @@ async def process_conversation(idx, item, memory_manager: EpisodicMemoryManager) ), ) + num_batch = 50 + kg_batch = [] session_idx = 0 + num_msg = 0 while True: session_idx += 1 session_id = f"session_{session_idx}" @@ -73,6 +255,9 @@ async def process_conversation(idx, item, memory_manager: EpisodicMemoryManager) context_messages.append( f"[{session_date_time}] {speaker}: {message_text}" ) + + id = uuid.uuid4() + ts = datetime.now() await memory.add_memory_episode( producer=speaker, @@ -80,14 +265,46 @@ async def process_conversation(idx, item, memory_manager: EpisodicMemoryManager) episode_content=message_text, episode_type="default", content_type=ContentType.STRING, - timestamp=datetime.now(), + timestamp=ts, metadata={ "source_timestamp": session_date_time, "source_speaker": speaker, "blip_caption": blip_caption, }, + uuid=id, ) + fmt = "%I:%M %p on %d %B, %Y" + ts = datetime.strptime(session_date_time, fmt) + kg_content = f"{speaker}: {message_text}" + if blip_caption: + kg_content += f" [Image Caption: {blip_caption}]" + kg_batch.append(Node( + uuid=id, + labels={"Episode"}, + # Make timestamp different for each episode + properties={ + "content": kg_content, + "timestamp": ts + timedelta(seconds=num_msg), + "session_id": memory.group_id(), + }, + )) + num_msg += 1 + + if len(kg_batch) >= num_batch: + t = time.perf_counter() + await kg.add_episode_bulk(kg_batch) + print(f"Added batch of {len(kg_batch)} episodes to KG in {(time.perf_counter() - t):.3f}s") + kg_batch = [] + if len(kg_batch) > 0: + t = time.perf_counter() + await kg.add_episode_bulk(kg_batch, True) + print(f"Added final batch of {len(kg_batch)} episodes to KG in {(time.perf_counter() - t):.3f}s") + kg_batch = [] + try: + kg.print_ingest_perf_matrix() + except Exception as e: + print(f"Error printing KG ingest perf matrix: {e}") await memory.close() tasks = [ diff --git a/evaluation/locomo/episodic_memory/locomo_search.py b/evaluation/locomo/episodic_memory/locomo_search.py index 7f045a8a4..3c330c650 100644 --- a/evaluation/locomo/episodic_memory/locomo_search.py +++ b/evaluation/locomo/episodic_memory/locomo_search.py @@ -3,6 +3,7 @@ import json import os import time +from datetime import datetime, timedelta from typing import Any, cast from dotenv import load_dotenv @@ -12,34 +13,75 @@ from memmachine.episodic_memory.episodic_memory_manager import ( EpisodicMemoryManager, ) +from memmachine.knowledge_graph.re_gpt4_1 import KnowledgeGraph +from memmachine.common.vector_graph_store import Node # This is adapted from Mem0 (https://github.com/mem0ai/mem0/blob/main/evaluation/prompts.py). # It is modified to work with MemMachine. ANSWER_PROMPT = """ - You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories to answer a question. +You are an analytical AI that reasons deeply about context before answering questions. Your task is to: + +1. FIRST: Look for direct, explicit answers in the context +2. ANALYZE the context thoroughly for relevant information +3. IDENTIFY patterns, connections, and implications +4. REASON about what the context suggests or implies +5. ANSWER based on direct evidence OR analysis + + +- Scan through ALL episodes and facts completely before answering +- Look for every explicit statement that relates to the question +- NEVER stop after finding the first answer - continue scanning for more +- When asking "what did X show Y", look for ALL items X showed Y on that date +- Collect multiple items, events, or details that answer the same question +- If not found directly, identify all context elements related to the question +- Look for patterns, themes, and implicit information in the context +- Consider what the context suggests beyond explicit statements +- Note any contradictions or missing information that affects the answer +- Pay close attention to temporal information and dates (validAt timestamps) +- For time-sensitive questions, prioritize more recent information +- Consider the chronological sequence of events when relevant +- CRITICAL: Ensure completeness by including ALL relevant items found +- If you find 2+ items for the same question, mention them all in your answer +- Be precise with details (specific types, colors, descriptions when available) +- Draw logical conclusions based on available evidence +- Don't give reasoning in the output + + +**Output Format** (JSON dict, don't give the JSON with ```json): +{"answer" : "Your direct, short(max 2 sentences) answer based on your analysis"} +""" + +USER_PROMPT = """ +${context} + + + +Question: ${question} + +""" + +WIKI_ANSWER_PROMPT=""" + You are an intelligent memory assistant tasked with retrieving accurate information from context memories to answer a question. # CONTEXT: - You have access to memories from a conversation. These memories contain - timestamped information that may be relevant to answering the question. + You have access to memories of contexts. These memories contain + information that may be relevant to answering the question. # INSTRUCTIONS: - 1. Carefully analyze all provided memories from both speakers - 2. Pay special attention to the timestamps to determine the answer - 3. If the question asks about a specific event or fact, look for direct evidence in the memories - 4. If the memories contain contradictory information, prioritize the most recent memory - 5. If there is a question about time references (like "last year", "two months ago", etc.), - calculate the actual date based on the memory timestamp. For example, if a memory from + 1. Carefully analyze all provided memories + 2. If the question asks about a specific event or fact, look for direct evidence in the memories + 3. If the memories contain contradictory information, prioritize the most recent memory + 4. If there is a question about time references (like "last year", "two months ago", etc.), + calculate the actual date based on the context. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years. For example, + 5. Always convert relative time references to specific dates, months, or years. For example, convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory timestamp. Ignore the reference while answering the question. - 7. Focus only on the content of the memories from both speakers. Do not confuse character - names mentioned in memories with the speakers. - 8. The answer should be less than 5-6 words. + 6. The answer should be less than 5-6 words. # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question - 2. Examine the timestamps and content of these memories carefully + 1. First, examine all memories that contain information related to the question, pay attention to FACTS section if available + 2. Examine the content of these memories carefully 3. Look for explicit mentions of dates, times, locations, or events that answer the question 4. If the answer requires calculation (e.g., converting relative time references), show your work 5. Formulate a precise, concise answer based solely on the evidence in the memories @@ -58,13 +100,59 @@ """ -def format_memory(episodes, summary) -> str: +def format_memory(episodes, summary, kg_episodes, fmt, include_timestamp: bool = True) -> str: + kg_final = [] + uuids = set() + + episode_nodes = [] + for e in episodes: + episode_nodes.append(Node( + uuid=e.uuid, + properties={ + "timestamp": datetime.strptime(e.user_metadata['source_timestamp'], fmt), + "content": f"{e.user_metadata['source_speaker']}: {e.content}{f' [Image Caption: {e.user_metadata["blip_caption"]}]' if e.user_metadata.get('blip_caption') else ''}" + }, + )) + uuids.add(e.uuid) + + num_filtered = 0 + for e in kg_episodes: + if e.uuid in uuids: + num_filtered += 1 + continue + # Drop microseconds and timezone info for consistency + ts_str = e.properties['timestamp'].strftime("%Y-%m-%d %H:%M:%S") + episode_nodes.append(Node( + uuid=e.uuid, + properties={ + "timestamp": datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S"), + "content": e.properties['content'] + }, + )) + # print(f"Filtered out {num_filtered} duplicate episodes from KG results") + + final = sorted(episode_nodes, key=lambda x: x.properties["timestamp"]) + episode_context = ( + # "\n" + # + "\n".join( + # [ + # f"[{episode.properties['timestamp']}] {episode.properties['content']}" + # for episode in kg_episodes + # ] + # ) + # + "\n".join( + # [ + # f"[{episode.user_metadata['source_timestamp']}] {episode.user_metadata['source_speaker']}: {episode.content}{f' [ATTACHED: {episode.user_metadata["blip_caption"]}]' if episode.user_metadata.get('blip_caption') else ''}" + # for episode in episodes + # ] + # ) + # + "\n\n" "\n" + "\n".join( [ - f"[{episode.user_metadata['source_timestamp']}] {episode.user_metadata['source_speaker']}: {episode.content}{f' [ATTACHED: {episode.user_metadata["blip_caption"]}]' if episode.user_metadata.get('blip_caption') else ''}" - for episode in episodes + f"[{episode.properties['timestamp']}] {episode.properties['content']}" if include_timestamp else f"{episode.properties['content']}" + for episode in final ] ) + "\n" @@ -74,8 +162,32 @@ def format_memory(episodes, summary) -> str: if summary else "" ) - return episode_context + "\n" + summary_context + # return episode_context + "\n" + summary_context + return episode_context + +def get_bedrock_reranker(): + import boto3 + from memmachine.common.reranker.amazon_bedrock_reranker import ( + AmazonBedrockReranker, + AmazonBedrockRerankerParams, + ) + region = "us-west-2" + + client = boto3.client( + "bedrock-agent-runtime", + region_name=region, + aws_access_key_id="key_id", + aws_secret_access_key="key", + ) + + return AmazonBedrockReranker( + AmazonBedrockRerankerParams( + client=client, + region=region, + model_id="amazon.rerank-v1:0" + ) + ) async def process_question( memory_manager: EpisodicMemoryManager, @@ -87,6 +199,7 @@ async def process_question( category, evidence, adversarial_answer, + limit, ): memory_start = time.time() memory = cast( @@ -102,46 +215,373 @@ async def process_question( short_term_episodes, long_term_episodes, summaries, - ) = await memory.query_memory(query=question, limit=30) + kg_episodes, + ) = await memory.query_memory(query=question, limit=limit) + episodes = long_term_episodes + short_term_episodes summary = summaries[0] if summaries else "" memory_end = time.time() - formatted_context = format_memory(episodes, summary) - prompt = ANSWER_PROMPT.format( - conversation_memories=formatted_context, question=question - ) + return episodes, summary + + # llm_start = time.time() + # rsp = await model.responses.create( + # model="gpt-4o-mini", + # max_output_tokens=4096, + # temperature=0.0, + # top_p=1, + # input=[{"role": "user", "content": prompt}], + # ) + # llm_end = time.time() + + # rsp_text = rsp.output_text + + # print_info = ( + # f"Question: {question}\n" + # f"Answer: {answer}\n" + # f"Response: {rsp_text}\n" + # f"Memory retrieval time: {memory_end - memory_start:.2f} seconds\n" + # f"LLM response time: {llm_end - llm_start:.2f} seconds\n" + # f"MEMORIES START\n{formatted_context}MEMORIES END\n" + # ) + + # return { + # "question": question, + # "locomo_answer": answer, + # "model_answer": rsp_text, + # "category": category, + # "evidence": evidence, + # "adversarial_answer": adversarial_answer, + # "conversation_memories": formatted_context, + # "print_info": print_info, + # } + +async def get_model_answer( + model: AsyncOpenAI, + group_id, + user, + qa, + formatted_context, + answer_prompt=USER_PROMPT, + perf_matrix={}, +): + question = qa["question"] + answer = qa.get("answer", "") + category = qa["category"] + evidence = qa.get("evidence", "") + adversarial_answer = qa.get("adversarial_answer", "") + + prompt = answer_prompt.format( + context=formatted_context, question=question + ) + + llm_start = time.time() rsp = await model.responses.create( - model="gpt-4o-mini", + model="gpt-4.1-mini", max_output_tokens=4096, temperature=0.0, top_p=1, - input=[{"role": "user", "content": prompt}], + input=[ + {"role": "system", "content": ANSWER_PROMPT}, + {"role": "user", "content": prompt} + ], ) llm_end = time.time() - rsp_text = rsp.output_text + # Remove leading \n and trailing \n from rsp.output_text + output_text = "" + for line in rsp.output_text.split("\n"): + if line == "```json" or line == "```": + continue + output_text += line + "\n" + + rsp_dict = {} + try: + if output_text.startswith("{"): + rsp_dict = json.loads(output_text) + else: + rsp_dict = {"answer": output_text} + print(f"WARNING: LLM response is not JSON:\n{rsp.output_text}\nUsing the string directly:\n{output_text}") + except Exception as e: + print(f"Parse LLM response\n:{output_text}\ngot error: {e}") + raise e + rsp_text = rsp_dict["answer"] - print( + print_info = ( f"Question: {question}\n" f"Answer: {answer}\n" f"Response: {rsp_text}\n" - f"Memory retrieval time: {memory_end - memory_start:.2f} seconds\n" + # f"Memory retrieval time: {memory_end - memory_start:.2f} seconds\n" f"LLM response time: {llm_end - llm_start:.2f} seconds\n" f"MEMORIES START\n{formatted_context}MEMORIES END\n" ) - return { + + question_response = { "question": question, "locomo_answer": answer, "model_answer": rsp_text, "category": category, "evidence": evidence, - "adversarial_answer": adversarial_answer, - "conversation_memories": formatted_context, + "evidence_text": qa["evidence_text"], } + for key, value in perf_matrix.items(): + question_response[key] = value + + question_response["conversation_memories"] = formatted_context + + return ( + category, + question_response, + ) + +def load_data( + start_line: int = 1, + num_cases: int = 100, + randomize: bool = True, +): + # dataset = "/home/tomz/MemMachine/evaluation/locomo/wikimultihop.json" + # dataset = "/home/tomz/MemMachine/evaluation/locomo/wiki-filter-gpt-4o-mini.json" + dataset = "/home/tomz/MemMachine/evaluation/locomo/wiki-filter-gpt-4.1.json" + print(f"Loading data from {dataset} line {start_line} to {num_cases}, randomize={randomize}") + contexts = [] + supporting_facts = [] + questions = [] + answers = [] + types = [] + i = 1 + with open(dataset, "r", encoding="utf-8") as f: + for line in f: + if i < start_line: + i += 1 + continue + if i > num_cases: + break + + line = line.strip() + if not line: + continue + obj = json.loads(line) + obj["context"] = json.loads(obj["context"]) + c_list = [] + key_to_sentences = {} + for key, sentences in obj["context"]: + key_to_sentences[key] = sentences + for s in sentences: + c = f"{key}: {s}" + if randomize: + insert_index = random.randrange(len(contexts) + 1) # 0..len inclusive + c_list.insert(insert_index, c) + else: + c_list.append(c) + contexts.append(c_list) + questions.append(obj["question"]) + answers.append(obj["answer"]) + types.append(obj["type"]) + golden_facts = json.loads(obj["supporting_facts"]) + fact_sents = [] + for fact in golden_facts: + key = fact[0] + sentence_idx = int(fact[1]) + fact_sents.append(key_to_sentences[key][sentence_idx]) + supporting_facts.append(fact_sents) + + i += 1 + return contexts, questions, answers, types, supporting_facts + +async def search_wikimultihop( + target_path: str, + limit: int, +): + memory_manager = EpisodicMemoryManager.create_episodic_memory_manager( + "locomo_config.yaml" + ) + + model = AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + timeout=600, + ) + + reranker = get_bedrock_reranker() + + em_total_time = 0.0 + kg_total_time = 0.0 + answer_total_time = 0.0 + + from neo4j import AsyncGraphDatabase + driver = AsyncGraphDatabase.driver( + "bolt://localhost:9999", + auth=( + "neo4j", + "password", + ), + connection_timeout=3600, + connection_acquisition_timeout=3600, + max_transaction_retry_time=3600, + max_connection_lifetime=3600, + keep_alive=True, + ) + + from memmachine.common.vector_graph_store.neo4j_vector_graph_store import Neo4jVectorGraphStore, Neo4jVectorGraphStoreParams + store = Neo4jVectorGraphStore( + Neo4jVectorGraphStoreParams( + driver=driver, + max_concurrent_transactions=200, + force_exact_similarity_search=False, + ) + ) + + await store.create_fulltext_index() + + from memmachine.common.embedder.openai_embedder import OpenAIEmbedder + embedder = OpenAIEmbedder( + { + "model": "text-embedding-3-small", + "api_key": os.getenv("OPENAI_API_KEY"), + } + ) + + contexts, questions, answers, types, supporting_facts = load_data(start_line=1, num_cases=250, randomize=False) + + import uuid + num_batch = 30 + kg_tasks = [] + em_tasks = [] + kg_episode_list = [] + triple_list = [] + l_episode_list = [] + summaries = [] + total_facts = 0 + hit_facts = 0 + ts = datetime.now() + for c_list, q, a, t in zip(contexts, questions, answers, types): + # ep_list = [] + # for c in c_list: + # ep_list.append( + # Node( + # uuid=uuid.uuid4(), + # properties={ + # "timestamp": ts + timedelta(seconds=len(ep_list)), + # "content": c, + # } + # ) + # ) + # kg_episode_list.append(ep_list) + # triple_list.append([]) + # l_episode_list.append([]) + # summaries.append("") + # ====== baseline above ====== + + kg_tasks.append( + search_kg(reranker, model, driver, store, embedder, q, 1, limit=limit) + ) + + em_tasks.append( + process_question( + memory_manager, + model, + "1", + "user", + q, + a, + t, + "", + "", + limit=limit, + ) + ) + + if len(kg_tasks) >= num_batch or (q == questions[-1]): + ts = time.perf_counter() + print(f"Async gathering {len(kg_tasks)} KG tasks...") + r_kg = await asyncio.gather(*kg_tasks) + # r_kg = [([], []) for _ in kg_tasks] # IGNORE KG + for t_list, e_list in r_kg: + kg_episode_list.append(e_list) + triple_list.append(t_list) + print(f"Gathered {len(kg_tasks)} KG tasks in {time.perf_counter() - ts:.2f}s") + kg_total_time += time.perf_counter() - ts + kg_tasks = [] + + ts = time.perf_counter() + print(f"Async gathering {len(em_tasks)} EM tasks...") + # r_em = await asyncio.gather(*em_tasks) + r_em = [([], "") for _ in em_tasks] # IGNORE EM + for episodes, summary in r_em: + l_episode_list.append(episodes) + summaries.append(summary) + print(f"Gathered {len(em_tasks)} EM tasks in {time.perf_counter() - ts:.2f}s") + em_total_time += time.perf_counter() - ts + em_tasks = [] + + print(f"Processed {len(kg_episode_list) if len(kg_episode_list) != 0 else len(l_episode_list)} / {len(questions)} questions.") + + # Run all response generation in parallel for current conversation + r_batch = 50 + r_tasks = [] + responses = [] + for ( + l_episodes, + summary, + triple_texts, + kg_episodes, + q, + a, + t, + facts + ) in zip(l_episode_list, + summaries, + triple_list, + kg_episode_list, + questions, + answers, + types, + supporting_facts): + formatted_context = format_memory(l_episodes, summary, triple_texts, kg_episodes, fmt="%Y-%m-%d %H:%M:%S", include_timestamp=False) + qa = { + "question": q, + "answer": a, + "category": t, + } + + total_facts += len(facts) + for fact in facts: + if fact in formatted_context: + hit_facts += 1 + + r_tasks.append( + get_model_answer( + model, + "1", + "1", + qa, + formatted_context, + answer_prompt=WIKI_ANSWER_PROMPT, + ) + ) + if len(r_tasks) >= r_batch or (q == questions[-1]): + ts = time.perf_counter() + print(f"Async gathering {len(r_tasks)} response tasks...") + responses.extend(await asyncio.gather(*r_tasks)) + print(f"Gathered {len(r_tasks)} response tasks in {time.perf_counter() - ts:.2f}s, {len(responses)}/{len(questions)} total.") + answer_total_time += time.perf_counter() - ts + r_tasks = [] + + results: dict[str, Any] = {} + for category, response in responses: + # print(f"---\n{response["print_info"][:300]} ---\n") + category_result = results.get(category, []) + category_result.append(response) + results[category] = category_result + + print(f"Total Episodic Memory retrieval time: {em_total_time:.2f} seconds") + print(f"Total Knowledge Graph retrieval time: {kg_total_time:.2f} seconds") + print(f"Total Answer generation time: {answer_total_time:.2f} seconds") + print(f"Recall of golden sentences: {hit_facts}/{total_facts} = {hit_facts/total_facts*100:.2f}%") + with open(target_path, "a") as f: + json.dump(results, f, indent=4) async def main(): parser = argparse.ArgumentParser() @@ -152,14 +592,18 @@ async def main(): parser.add_argument( "--target-path", required=True, help="Path to the target data file" ) + parser.add_argument( + "--limit", required=False, help="Path to the target data file" + ) args = parser.parse_args() data_path = args.data_path target_path = args.target_path + limit = int(args.limit) if args.limit else 30 - with open(data_path, "r") as f: - locomo_data = json.load(f) + # await search_wikimultihop(target_path, limit) + # return memory_manager = EpisodicMemoryManager.create_episodic_memory_manager( "locomo_config.yaml" @@ -167,23 +611,165 @@ async def main(): model = AsyncOpenAI( api_key=os.getenv("OPENAI_API_KEY"), + timeout=600, + ) + + em_total_time = 0.0 + kg_total_time = 0.0 + answer_total_time = 0.0 + total_input_tolens = 0 + total_output_tokens = 0 + total_num_sufficiency_checks = 0 + + from neo4j import AsyncGraphDatabase + driver = AsyncGraphDatabase.driver( + "bolt://localhost:9999", + auth=( + "neo4j", + "password", + ), + connection_timeout=3600, # seconds + # how long to wait for a pooled connection + connection_acquisition_timeout=3600, # seconds + # built-in retry window for transient failures + max_transaction_retry_time=3600, + max_connection_lifetime=3600, + keep_alive=True, + ) + + from memmachine.common.vector_graph_store.neo4j_vector_graph_store import Neo4jVectorGraphStore, Neo4jVectorGraphStoreParams + store = Neo4jVectorGraphStore( + Neo4jVectorGraphStoreParams( + driver=driver, + max_concurrent_transactions=200, + force_exact_similarity_search=False, + ) + ) + + await store.create_fulltext_index() + + from memmachine.common.embedder.openai_embedder import OpenAIEmbedder + embedder = OpenAIEmbedder( + { + "model": "text-embedding-3-small", + "api_key": os.getenv("OPENAI_API_KEY"), + } + ) + + reranker = get_bedrock_reranker() + + model_name = "gpt-4.1-mini" + print(f"Using KnowledgeGraph with model {model_name}...") + kg = KnowledgeGraph( + model_name=model_name, + model=model, + embedder=embedder, + store=store, + reranker=reranker, ) + with open(data_path, "r") as f: + locomo_data = json.load(f) + + recall_hit = 0 + num_facts = 0 + num_episodes_returned = 0 + num_used_kg = 0 + num_used_longterm_only = 0 + num_used_both = 0 + num_questions = 0 + num_processed = 0 + skip_to_index = 0 + run_until_index = 20 results: dict[str, Any] = {} for idx, item in enumerate(locomo_data): if "conversation" not in item: continue + + if num_processed < skip_to_index: + num_processed += 1 + continue + + print(f"Processing questions for group {idx}...") conversation = item["conversation"] user = conversation["speaker_a"] - qa_list = item["qa"] + break_session = 0 + break_sentence = 0 + full_text = "" + evidence_to_text = {} + session_idx = 0 + num_msg = 0 + while True: + session_idx += 1 + session_id = f"session_{session_idx}" + if session_id not in conversation: + break - print(f"Processing questions for group {idx}...") + session = conversation[session_id] + session_date_time = conversation[f"{session_id}_date_time"] + + for message in session: + num_msg += 1 + speaker = message["speaker"] + dia_id = message["dia_id"] + text = message["text"] + img_url = message.get("img_url") + blip_caption = message.get("blip_caption") + + text = f"{speaker}: {text}" + # if img_url: + # text += f" [Image URL: {img_url}]\n" + # if blip_caption: + # text += f" [Image Caption: {blip_caption}]\n" + + evidence_to_text[dia_id] = text + full_text += text + if num_msg == 300: + session, sentence = dia_id.split(":") + break_session = int(session[1:]) + break_sentence = int(sentence) + # print(f"Breaking at dia_id {dia_id}, session {break_session}, sentence {break_sentence}") + break + + qas = item["qa"] + qa_list = [] + for qa in qas: + # if qa["category"] != 3: + # continue + if qa["category"] == 5: + continue + ev_ids = [] + for ev in qa["evidence"]: + if "," in ev: + ids = ev.split(",") + ev_ids.extend(ids) + elif ";" in ev: + ids = ev.split(";") + ev_ids.extend(ids) + else: + ev_ids.append(ev) + + if len(ev_ids) == 0: + continue + qa_list.append(qa) + + print(f"Testing on {len(qa_list)} questions.") group_id = f"group_{idx}" - async def respond_question(qa): + # qa_list = [{ + # "question": "When did Melanie go camping in June?", + # "answer": "The week before 27 June 2023", + # "category": "2", + # "evidence": ["D4:6"], + # }] + # group_id = "group_0" + + # qa_list = qa_list[50:55] + + async def respond_question(qa, limit): question = qa["question"] answer = qa.get("answer", "") category = qa["category"] @@ -191,7 +777,7 @@ async def respond_question(qa): adversarial_answer = qa.get("adversarial_answer", "") - question_response = await process_question( + return await process_question( memory_manager, model, group_id, @@ -201,22 +787,254 @@ async def respond_question(qa): category, evidence, adversarial_answer, - ) - return ( - category, - question_response, + limit=limit, ) responses = [] + em_converted_list = [] + summaries = [] + contexts_tasks = [] + em_batch = 50 + # Get EM responses first for qa in qa_list: - responses.append(await respond_question(qa)) + # Always limit higher, rerank and truncate later + contexts_tasks.append(respond_question(qa, limit*3)) + if len(contexts_tasks) >= em_batch or qa['question'] == qa_list[-1]['question']: + ts = time.perf_counter() + print(f"Async gathering {len(contexts_tasks)} contexts tasks...") + r = await asyncio.gather(*contexts_tasks) + # r = [([], "") for _ in contexts_tasks] # IGNORE EM + for episodes, summary in r: + em_convert = [] + # Convert EM episodes to KG episodes + for e in episodes: + em_convert.append(Node( + uuid=e.uuid, + properties={ + "timestamp": datetime.strptime(e.user_metadata['source_timestamp'], "%I:%M %p on %d %B, %Y"), + "content": f"{e.user_metadata['source_speaker']}: {e.content}{f' [Image Caption: {e.user_metadata["blip_caption"]}]' if e.user_metadata.get('blip_caption') else ''}" + }, + )) + em_converted_list.append(em_convert) + summaries.append(summary) + print(f"Gathered {len(contexts_tasks)} contexts tasks in {time.perf_counter() - ts:.2f}s") + contexts_tasks = [] + em_total_time += time.perf_counter() - ts + + # # Read em_converted_list from binary file to reuse + # import pickle + # with open(f"em_converted_group_{idx}.pkl", "rb") as f: + # em_converted_list = pickle.load(f) + + suff_tasks = [] + for em_convert, qa in zip(em_converted_list, qa_list): + suff_tasks.append(kg.check_sufficiency_batch(em_convert, [], qa["question"])) + total_num_sufficiency_checks += len(em_convert) // limit + (1 if len(em_convert) % limit != 0 else 0) + + print(f"Async gathering {len(suff_tasks)} sufficiency check tasks for EM...") + t = time.perf_counter() + r_suff = await asyncio.gather(*suff_tasks) + print(f"Gathered {len(suff_tasks)} sufficiency check tasks in {time.perf_counter() - t:.2f}s") + + # # Dump em_converted_list as binary file for reuse + # import pickle + # with open(f"em_converted_group_{idx}.pkl", "wb") as f: + # pickle.dump(em_converted_list, f)\ + + # # Dump r_suff as binary file for reuse + # import pickle + # with open(f"r_suff_group_{idx}.pkl", "wb") as f: + # pickle.dump(r_suff, f) + + + # # USE EM ONLY + # r_suff = [(True, [], [], "", 0, 0) for em_convert in em_converted_list] + + # Get KG response only if EM is insufficient + kg_batch = 50 + kg_tasks = [] + cur_real_kg_task = 0 + perf_list = [] + result_episodes_list = [] + em_suff_list = [] + kg_suff_list = [] + for (is_em_sufficient, sorted_suff_em_episodes, possible_relevant, reasoning_str, itoken, otoken), qa, em_convert in zip(r_suff, qa_list, em_converted_list): + total_input_tolens += itoken + total_output_tokens += otoken + em_suff_list.append(is_em_sufficient) + + # if EM == True, skip KG and return EM + # if EM == False, search KG and return filtered EM + KG + if is_em_sufficient: + res_e_list = em_convert + if len(em_convert) > limit: + cohere_res = await kg.cohere_rerank(em_convert, score_threshold=0.0, query=qa["question"], limit=limit) + res_e_list = [e for e, _ in cohere_res] + kg_tasks.append( + asyncio.sleep(0, result=(res_e_list, {"used_em": True, "used_kg": False, "reasoning": reasoning_str}, False)) + ) + else: + possible = sorted_suff_em_episodes + list(possible_relevant) + kg_tasks.append( + kg.search(query=qa["question"], possible_episodes=possible, session_id=group_id, limit=limit) + ) + # if len(em_convert) > limit: + # cohere_res = await kg.cohere_rerank(em_convert, score_threshold=0.0, query=qa["question"], limit=limit) + # res_e_list = [e for e, _ in cohere_res] + # kg_tasks.append( + # asyncio.sleep(0, result=(res_e_list, {"reasoning": reasoning_str}, False)) + # ) + num_used_kg += 1 + cur_real_kg_task += 1 + + if cur_real_kg_task >= kg_batch or qa['question'] == qa_list[-1]['question']: + print(f"Async gathering {cur_real_kg_task} KG tasks...") + ts = time.perf_counter() + r = await asyncio.gather(*kg_tasks) + kg_total_time += time.perf_counter() - ts + print(f"Gathered {cur_real_kg_task} KG tasks in {time.perf_counter() - ts:.2f}s") + + for e_list, perf_matrix, is_kg_sufficient in r: + result_episodes_list.append(e_list) + perf_list.append(perf_matrix) + kg_suff_list.append(is_kg_sufficient) + total_input_tolens += perf_matrix.get("num_llm_input_tokens", 0) + total_output_tokens += perf_matrix.get("num_llm_output_tokens", 0) + total_num_sufficiency_checks += perf_matrix.get("num_sufficiency_checks", 0) + kg_tasks = [] + cur_real_kg_task = 0 + + # Run all response generation in parallel for current conversation + r_tasks = [] + for qa, result_episodes_kg_formatted, perf_matrix in zip(qa_list, result_episodes_list, perf_list): + num_questions += 1 + evidence_strs = [] + + num_episodes_returned += len(result_episodes_kg_formatted) + formatted_context = format_memory([], None, result_episodes_kg_formatted, fmt="%Y-%m-%d %H:%M:%S") + + num_cur_facts = 0 + num_cur_hits = 0 + for ev in qa["evidence"]: + if "," in ev: + ids = ev.split(",") + for id in ids: + num_cur_facts += 1 + evidence_strs.append(evidence_to_text.get(id.strip(), "")) + if evidence_strs[-1] in formatted_context: + num_cur_hits += 1 + elif ";" in ev: + ids = ev.split(";") + for id in ids: + num_cur_facts += 1 + evidence_strs.append(evidence_to_text.get(id.strip(), "")) + if evidence_strs[-1] in formatted_context: + num_cur_hits += 1 + else: + num_cur_facts += 1 + evidence_strs.append(evidence_to_text.get(ev.strip(), "")) + if evidence_strs[-1] in formatted_context: + num_cur_hits += 1 + num_facts += num_cur_facts + recall_hit += num_cur_hits + perf_matrix["recall"] = f"{num_cur_hits}/{num_cur_facts} = {num_cur_hits/num_cur_facts*100:.2f}%" if num_cur_facts > 0 else "N/A" + + qa["evidence_text"] = evidence_strs + r_tasks.append( + get_model_answer( + model, + group_id, + user, + qa, + formatted_context, + USER_PROMPT, + perf_matrix, + ) + ) + + # # Baseline: use ground-truth evidence only + # r_tasks = [] + # for qa in qa_list: + # context = f"Evidence:\n" + # for ev in qa["evidence"]: + # if "," in ev: + # ids = ev.split(",") + # for id in ids: + # context += evidence_to_text.get(id.strip(), "") + "\n" + # elif ";" in ev: + # ids = ev.split(";") + # for id in ids: + # context += evidence_to_text.get(id.strip(), "") + "\n" + # else: + # context += evidence_to_text.get(ev.strip(), "") + "\n" + # context += "\nFull Conversation:\n" + full_text + + # r_tasks.append( + # get_model_answer( + # model, + # group_id, + # user, + # qa, + # context, + # ) + # ) + + print(f"Async gathering {len(r_tasks)} response tasks...") + ts = time.perf_counter() + responses.extend(await asyncio.gather(*r_tasks)) + print(f"Gathered {len(r_tasks)} response tasks in {time.perf_counter() - ts:.2f}s") + answer_total_time += time.perf_counter() - ts for category, response in responses: + # print(f"---\n{response["print_info"][:300]} ---\n") category_result = results.get(category, []) category_result.append(response) results[category] = category_result + + if num_processed >= run_until_index: + break + num_processed += 1 + # break + + final_matrix = f"""Total Episodic Memory retrieval time: {em_total_time:.2f} seconds +Total Knowledge Graph retrieval time: {kg_total_time:.2f} seconds +Average question response time: {(em_total_time + kg_total_time) / num_questions:.2f} seconds +Total Answer generation time: {answer_total_time:.2f} seconds +Total LLM input tokens: {total_input_tolens} +Average LLM input tokens per question: {total_input_tolens}/{num_questions} = {total_input_tolens/num_questions:.2f} +Total LLM output tokens: {total_output_tokens} +Average LLM output tokens per question: {total_output_tokens}/{num_questions} = {total_output_tokens/num_questions:.2f} +Total LLM tokens: {total_input_tolens + total_output_tokens} +Average LLM tokens per question: {total_input_tolens + total_output_tokens}/{num_questions} = {(total_input_tolens + total_output_tokens)/num_questions:.2f} +Total number of sufficiency checks: {total_num_sufficiency_checks} +Average number of sufficiency checks per question: {total_num_sufficiency_checks}/{num_questions} = {total_num_sufficiency_checks/num_questions:.2f} +Overall Evidence Recall: {recall_hit}/{num_facts} = {recall_hit/num_facts*100:.2f}% +Overall Evidence Precision: {recall_hit}/{num_episodes_returned} = {recall_hit/num_episodes_returned*100:.2f}% +Average episodes returned per question: {num_episodes_returned}/{num_questions} = {num_episodes_returned/num_questions:.2f} +Number of questions used KG: {num_used_kg}/{num_questions} = {num_used_kg/num_questions*100:.2f}% +""" + for cat, res_list in results.items(): + res_list[0]["final_matrix"] = final_matrix + break - with open(target_path, "w") as f: + # print(f"Total Episodic Memory retrieval time: {em_total_time:.2f} seconds") + # print(f"Total Knowledge Graph retrieval time: {kg_total_time:.2f} seconds") + # print(f"Total Answer generation time: {answer_total_time:.2f} seconds") + # print(f"Total LLM input tokens: {total_input_tolens}") + # print(f"Average LLM input tokens per question: {total_input_tolens}/{num_questions} = {total_input_tolens/num_questions:.2f}") + # print(f"Total LLM output tokens: {total_output_tokens}") + # print(f"Average LLM output tokens per question: {total_output_tokens}/{num_questions} = {total_output_tokens/num_questions:.2f}") + # print(f"Total LLM tokens: {total_input_tolens + total_output_tokens}") + # print(f"Average LLM tokens per question: {total_input_tolens + total_output_tokens}/{num_questions} = {(total_input_tolens + total_output_tokens)/num_questions:.2f}") + # print(f"Total number of sufficiency checks: {total_num_sufficiency_checks}") + # print(f"Average number of sufficiency checks per question: {total_num_sufficiency_checks}/{num_questions} = {total_num_sufficiency_checks/num_questions:.2f}") + # print(f"Overall Evidence Recall: {recall_hit}/{num_facts} = {recall_hit/num_facts*100:.2f}%") + # print(f"Overall Evidence Precision: {recall_hit}/{num_episodes_returned} = {recall_hit/num_episodes_returned*100:.2f}%") + # print(f"Average episodes returned per question: {num_episodes_returned}/{num_questions} = {num_episodes_returned/num_questions:.2f}") + # print(f"Number of questions used KG: {num_used_kg}/{num_questions} = {num_used_kg/num_questions*100:.2f}%") + # print(f"Number of questions using long-term memory only: {num_used_longterm_only}/{num_questions} = {num_used_longterm_only/num_questions*100:.2f}%") + # print(f"Number of questions using both long-term and KG memory: {num_used_both}/{num_questions} = {num_used_both/num_questions*100:.2f}%") + with open(target_path, "a") as f: json.dump(results, f, indent=4) diff --git a/evaluation/locomo/episodic_memory/run.sh b/evaluation/locomo/episodic_memory/run.sh new file mode 100755 index 000000000..2a077687b --- /dev/null +++ b/evaluation/locomo/episodic_memory/run.sh @@ -0,0 +1,15 @@ +TEST_NAME=$1 +RESULT_FILE="./result/output_${TEST_NAME}.json" +EVAL_FILE="./result/evaluation_metrics_${TEST_NAME}.json" +FINAL_SCORE_FILE="./result/${TEST_NAME}.result" + +rm -f $RESULT_FILE $EVAL_FILE $FINAL_SCORE_FILE + +set -xe +export OPENAI_API_KEY=key +export PYTHONUNBUFFERED=1 +#python3 -u locomo_ingest.py --data-path ../locomo10.json | tee ingest.log +python -u locomo_search.py --data-path ../locomo10.json --target-path $RESULT_FILE +python locomo_evaluate.py --input_file $RESULT_FILE --output_file $EVAL_FILE +python generate_scores.py $EVAL_FILE > $FINAL_SCORE_FILE +cat $FINAL_SCORE_FILE diff --git a/evaluation/locomo/locomo10.json b/evaluation/locomo/locomo10.json index d04e59eac..81b4c7c17 100644 --- a/evaluation/locomo/locomo10.json +++ b/evaluation/locomo/locomo10.json @@ -44,7 +44,7 @@ }, { "question": "When did Melanie run a charity race?", - "answer": "The sunday before 25 May 2023", + "answer": "The Saturday before 25 May 2023", "evidence": [ "D2:1" ], @@ -265,7 +265,7 @@ "question": "When did Melanie go camping in June?", "answer": "The week before 27 June 2023", "evidence": [ - "D4:8" + "D4:6" ], "category": 2 }, @@ -17356,7 +17356,6 @@ "answer": "\"Eternal Sunshineof the Spotless Mind\"", "evidence": [ "D1:18", - "D", "D1:20" ], "category": 4 @@ -24164,7 +24163,7 @@ "D2:7", "D4:7", "D5:15", - "D:11:26", + "D11:26", "D20:21", "D26:36" ], @@ -54209,7 +54208,7 @@ "question": "How might Evan and Sam's experiences with health and lifestyle changes influence their approach to stress and challenges?", "answer": "Their experiences likely lead them to view challenges as opportunities for growth and change. They both have embraced healthier lifestyles, indicating a proactive approach to managing stress and challenges.", "evidence": [ - "D9:1 D4:4 D4:6" + "D9:1;D4:4;D4:6" ], "category": 3 }, @@ -54276,7 +54275,7 @@ "question": "What role does nature and the outdoors play in Evan and Sam's mental well-being?", "answer": "Nature and outdoor activities seem to be significant stress relievers and sources of joy for both Evan and Sam. These activities likely contribute positively to their mental well-being.", "evidence": [ - "D22:1 D22:2 D9:10 D9:11" + "D22:1;D22:2;D9:10;D9:11" ], "category": 3 }, @@ -54352,7 +54351,7 @@ "question": "How do Evan and Sam use creative outlets to cope with life's challenges?", "answer": "Evan and Sam use creative activities, like painting and writing, as therapeutic tools to express themselves and cope with stress.", "evidence": [ - "D21:18 D21:22 D11:15 D11:19" + "D21:18,D21:22,D11:15,D11:19" ], "category": 3 }, diff --git a/src/memmachine/common/embedder/openai_embedder.py b/src/memmachine/common/embedder/openai_embedder.py index 5e763b81c..cf965de13 100644 --- a/src/memmachine/common/embedder/openai_embedder.py +++ b/src/memmachine/common/embedder/openai_embedder.py @@ -70,7 +70,7 @@ def __init__(self, config: dict[str, Any]): self._model = model - temp_client = openai.OpenAI(api_key=api_key, base_url=config.get("base_url")) + temp_client = openai.OpenAI(api_key=api_key, base_url=config.get("base_url"), timeout=600) # https://platform.openai.com/docs/guides/embeddings#embedding-models dimensions = config.get("dimensions") @@ -111,7 +111,7 @@ def __init__(self, config: dict[str, Any]): self._dimensions = dimensions self._client = openai.AsyncOpenAI( - api_key=api_key, base_url=config.get("base_url") + api_key=api_key, base_url=config.get("base_url"), timeout=600, ) metrics_factory = config.get("metrics_factory") diff --git a/src/memmachine/common/reranker/reranker_builder.py b/src/memmachine/common/reranker/reranker_builder.py index e3a856fb9..b6ae9ebf8 100644 --- a/src/memmachine/common/reranker/reranker_builder.py +++ b/src/memmachine/common/reranker/reranker_builder.py @@ -70,6 +70,9 @@ def build( from .bm25_reranker import BM25Reranker, BM25RerankerParams + import nltk + nltk.download('stopwords') + language = config.get("language", "english") stop_words = stopwords.words(language) diff --git a/src/memmachine/common/vector_graph_store/neo4j_vector_graph_store.py b/src/memmachine/common/vector_graph_store/neo4j_vector_graph_store.py index 9b132ec1c..108981574 100644 --- a/src/memmachine/common/vector_graph_store/neo4j_vector_graph_store.py +++ b/src/memmachine/common/vector_graph_store/neo4j_vector_graph_store.py @@ -10,10 +10,12 @@ import re from collections.abc import Awaitable, Collection, Mapping from typing import Any, cast +import unicodedata from uuid import UUID +from collections import defaultdict from neo4j import AsyncDriver -from neo4j.graph import Node as Neo4jNode +from neo4j.graph import Node as Neo4jNode, Relationship as Neo4jRelationship from neo4j.time import DateTime as Neo4jDateTime from pydantic import BaseModel, Field, InstanceOf @@ -78,6 +80,77 @@ def __init__(self, params: Neo4jVectorGraphStoreParams): self._vector_index_name_cache: set[str] = set() + @async_locked + async def create_fulltext_index(self): + relation = Neo4jVectorGraphStore._sanitize_name("HAS_RELATION") + field_name = Neo4jVectorGraphStore._sanitize_name("triple_text") + await self._driver.execute_query( + "CREATE FULLTEXT INDEX rel_tripletext_fts IF NOT EXISTS\n" + f"FOR ()-[r:{relation}]-() ON EACH [r.{field_name}]\n" + """ + OPTIONS { + indexConfig: { `fulltext.analyzer`: 'simple' } + }; + """, + ) + + await self._driver.execute_query( + """ + CREATE FULLTEXT INDEX node_name_fts IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name] + OPTIONS { + indexConfig: { `fulltext.analyzer`: 'simple' } + }; + """, + ) + + await self._driver.execute_query( + # Fast lookups by uuid + """ + CREATE CONSTRAINT entity_uuid_unique IF NOT EXISTS + FOR (n:Entity) REQUIRE n.uuid IS UNIQUE; + """, + ) + + await self._driver.execute_query( + # Fast lookups by uuid + """ + CREATE CONSTRAINT episode_uuid_unique IF NOT EXISTS + FOR (n:Episode) REQUIRE n.uuid IS UNIQUE; + """, + ) + + await self._driver.execute_query( + # Fast lookups by uuid + """ + CREATE CONSTRAINT episode_uuid_unique IF NOT EXISTS + FOR (n:EpisodeCluster) REQUIRE n.uuid IS UNIQUE; + """, + ) + + session_id = Neo4jVectorGraphStore._sanitize_name("session_id") + await self._driver.execute_query( + f""" + CREATE INDEX episode_session_id_idx IF NOT EXISTS + FOR (n:EpisodeCluster) ON (n.{session_id}); + """, + ) + + await self._driver.execute_query( + f""" + CREATE INDEX episode_session_id_idx IF NOT EXISTS + FOR (n:Episode) ON (n.{session_id}); + """, + ) + + await self._driver.execute_query( + f""" + CREATE INDEX entity_session_id_idx IF NOT EXISTS + FOR (n:Entity) ON (n.{session_id}); + """, + ) + + async def add_nodes(self, nodes: Collection[Node]): labels_nodes_map: dict[tuple[str, ...], list[Node]] = {} for node in nodes: @@ -272,6 +345,7 @@ async def search_related_nodes( required_labels: Collection[str] | None = None, required_properties: Mapping[str, Property] = {}, include_missing_properties: bool = False, + index_search_label: str = "", ) -> list[Node]: if not (find_sources or find_targets): return [] @@ -285,21 +359,30 @@ async def search_related_nodes( else ["[]"] ) + session_id_index = "" + session_id = "" + required_properties_copy = dict(required_properties) + if "session_id" in required_properties_copy: + session_id = required_properties_copy["session_id"] + session_id_index = f"{{{Neo4jVectorGraphStore._sanitize_name('session_id')}: $session_id}}" + required_properties_copy.pop("session_id") + search_related_nodes_tasks = [ async_with( self._semaphore, self._driver.execute_query( - "MATCH\n" - " (m {uuid: $node_uuid})" - f" {'-' if find_targets else '<-'}" - f" {query_typed_relation}" - f" {'-' if find_sources else '->'}" - f" (n{Neo4jVectorGraphStore._format_labels(required_labels)})" + f"MATCH (m{index_search_label} {{uuid: $node_uuid}})\n" + f"MATCH (n{Neo4jVectorGraphStore._format_labels(required_labels)} {session_id_index})\n" + f"MATCH (m)" + f" {'-' if find_targets else '<-'}" + f" {query_typed_relation}" + f" {'-' if find_sources else '->'} " + f"(n{Neo4jVectorGraphStore._format_labels(required_labels)})\n" f"WHERE { Neo4jVectorGraphStore._format_required_properties( 'n', - required_properties, - include_missing_properties, + required_properties_copy, + include_missing_properties ) }\n" "RETURN n\n" @@ -308,8 +391,10 @@ async def search_related_nodes( limit=limit, required_properties={ Neo4jVectorGraphStore._sanitize_name(key): value - for key, value in required_properties.items() + for key, value in required_properties_copy.items() }, + session_id=session_id, + timeout=3600, ), ) for query_typed_relation in query_typed_relations @@ -326,6 +411,88 @@ async def search_related_nodes( return list(related_nodes)[:limit] + async def search_related_nodes_edges_batch( + self, + node_uuids: list[UUID], + allowed_relations: Collection[str] | None = None, + find_sources: bool = True, + find_targets: bool = True, + limit: int | None = None, + required_labels: Collection[str] | None = None, + required_properties: Mapping[str, Property] = {}, + include_missing_properties: bool = False, + index_search_label: str = "", + ) -> tuple[list[Node], list[Edge]]: + if not (find_sources or find_targets): + return [] + + query_typed_relations = ( + [ + f"[r:{Neo4jVectorGraphStore._sanitize_name(relation)}]" + for relation in allowed_relations + ] + if allowed_relations is not None + else ["[r]"] + ) + + session_id_index = "" + session_id = "" + required_properties_copy = dict(required_properties) + if "session_id" in required_properties_copy: + session_id = required_properties_copy["session_id"] + session_id_index = f"{{{Neo4jVectorGraphStore._sanitize_name('session_id')}: $session_id}}" + required_properties_copy.pop("session_id") + + search_related_nodes_tasks = [ + async_with( + self._semaphore, + self._driver.execute_query( + "UNWIND $uuids AS id\n" + f"MATCH (m{index_search_label} {{uuid: id}})\n" + f"MATCH (n{Neo4jVectorGraphStore._format_labels(required_labels)} {session_id_index})\n" + f"MATCH (m)" + f" {'-' if find_targets else '<-'}" + f" {query_typed_relation}" + f" {'-' if find_sources else '->'} " + f"(n)\n" + f"WHERE { + Neo4jVectorGraphStore._format_required_properties( + 'n', + required_properties_copy, + include_missing_properties + ) + }\n" + "RETURN n, r, startNode(r) AS src, endNode(r) AS dst\n" + f"{'LIMIT $limit' if limit is not None else ''}", + uuids=[str(node_uuid) for node_uuid in node_uuids], + limit=limit, + required_properties={ + Neo4jVectorGraphStore._sanitize_name(key): value + for key, value in required_properties_copy.items() + }, + session_id=session_id, + timeout=3600, + ), + ) + for query_typed_relation in query_typed_relations + ] + + results = await asyncio.gather(*search_related_nodes_tasks) + + related_nodes: set[Node] = set() + related_edges: set[Edge] = set() + for records, _, _ in results: + related_neo4j_nodes = [record["n"] for record in records] + related_nodes.update( + Neo4jVectorGraphStore._nodes_from_neo4j_nodes(related_neo4j_nodes) + ) + related_neo4j_relations = [record["r"] for record in records] + related_edges.update( + Neo4jVectorGraphStore._edges_from_neo4j_relationships(related_neo4j_relations) + ) + + return list(related_nodes), list(related_edges) + async def search_directional_nodes( self, by_property: str, @@ -403,6 +570,344 @@ async def search_matching_nodes( matching_neo4j_nodes = [record["n"] for record in records] return Neo4jVectorGraphStore._nodes_from_neo4j_nodes(matching_neo4j_nodes) + def lucene_sanitize(self, query: str) -> str: + # Escape special characters from a query before passing into Lucene + # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / + escape_map = str.maketrans( + { + '+': r'\+', + '-': r'\-', + '&': r'\&', + '|': r'\|', + '!': r'\!', + '(': r'\(', + ')': r'\)', + '{': r'\{', + '}': r'\}', + '[': r'\[', + ']': r'\]', + '^': r'\^', + '"': r'\"', + '~': r'\~', + '*': r'\*', + '?': r'\?', + ':': r'\:', + '\\': r'\\', + '/': r'\/', + 'O': r'\O', + 'R': r'\R', + 'N': r'\N', + 'T': r'\T', + 'A': r'\A', + 'D': r'\D', + } + ) + return query.translate(escape_map) + + def normalize_name(self, s: str) -> str: + s = unicodedata.normalize("NFKC", s) + s = s.lower() + s = re.sub(r"\s+", " ", s).strip() + return s + + def rrf( + self, + lists: list[list[Any]], + weights: list[float] | None = None, + k: int = 50, + ) -> list[tuple[Any, float]]: + """ + Reciprocal Rank Fusion (RRF) implementation. + Args: + lists (list[list[str]]): List of ranked lists to fuse. + weights (list[float] | None): Weights for each ranked list. If None, equal weights are used. + k (int): Constant to control the influence of rank. Smaller k gives more weight to higher ranks. + Returns: + list[tuple[str, float]]: List of tuples containing item and its fused score, sorted by score in descending order. + """ + if not lists: + return [] + + n = len(lists) + if weights is None: + weights = [1.0] * n + elif len(weights) != n: + raise ValueError(f"RRF: weights length {len(weights)} must equal number of lists {n}") + + scores: dict[str, float] = defaultdict(float) + + for w, L in zip(weights, lists): + for rank, _id in enumerate(L, start=1): + if w > 0: + scores[_id] += w * (1.0 / (k + rank)) + + return sorted(scores.items(), key=lambda x: x[1], reverse=True) + + async def search_similar_edges( + self, + query_text: str, + query_embedding: list[float], + embedding_property_name: str, + similarity_threshold: float = 0.2, + limit: int | None = 100, + allowed_relations: set[str] | None = None, + required_properties: dict[str, Property] = {}, + include_missing_properties: bool = False, + ) -> list[Edge]: + query_typed_relations = ( + [ + f":{Neo4jVectorGraphStore._sanitize_name(relation)}" + for relation in allowed_relations + ] + if allowed_relations is not None + else [""] + ) + + sanitized_embedding_property_name = Neo4jVectorGraphStore._sanitize_name( + embedding_property_name + ) + + vector_index_name = ( + Neo4jVectorGraphStore._edge_vector_index_name( + Neo4jVectorGraphStore._sanitize_name(next(iter(allowed_relations))), + sanitized_embedding_property_name, + ) + if allowed_relations is not None and len(allowed_relations) > 0 + else None + ) + + await self._create_edge_vector_index_if_not_exist( + relations=cast(Collection[str], allowed_relations), + embedding_property_name=embedding_property_name, + dimensions=len(query_embedding), + similarity_metric=SimilarityMetric.COSINE, + ) + + vector_search_tasks = [ + async_with( + self._semaphore, + self._driver.execute_query( + f""" + CALL {{ + CALL db.index.vector.queryRelationships($vector_index_name, $prelimit, $query_embedding) + YIELD relationship AS r, score + WHERE score > $similarity_threshold + AND { + Neo4jVectorGraphStore._format_required_properties( + "r", required_properties, include_missing_properties + ) + } + ORDER BY score DESC + LIMIT $limit + RETURN r, score + }} + RETURN r, startNode(r) AS src, endNode(r) AS dst, score AS similarity + ORDER BY similarity DESC + """, + vector_index_name=vector_index_name, + query_embedding=query_embedding, + similarity_threshold=similarity_threshold, + limit=limit, + prelimit=(limit * 5) if limit is not None else 100, + required_properties={ + Neo4jVectorGraphStore._sanitize_name(key): value + for key, value in required_properties.items() + }, + include_missing_properties=include_missing_properties, + timeout=3600, + ), + ) + for relation in query_typed_relations + ] + + results = await asyncio.gather(*vector_search_tasks) + + vector_search_edges = [] + edge_source_node_map = {} + for records, _, _ in results: + rels = [record["r"] for record in records] + vector_search_edges.extend( + Neo4jVectorGraphStore._edges_from_neo4j_relationships(rels) + ) + neo4j_nodes = [record["src"] for record in records] + src_nodes = Neo4jVectorGraphStore._nodes_from_neo4j_nodes(neo4j_nodes) + edge_source_node_map.update({ + n.uuid: n for n in src_nodes + }) + + fulltext_search_tasks = [ + async_with( + self._semaphore, + self._driver.execute_query( + f""" + CALL {{ + CALL db.index.fulltext.queryRelationships('rel_tripletext_fts', $q_text, {{limit: $prelimit}}) + YIELD relationship AS r, score + WHERE { + Neo4jVectorGraphStore._format_required_properties( + "r", required_properties, include_missing_properties + ) + } + ORDER BY score DESC + LIMIT $limit + RETURN r, score + }} + RETURN r, startNode(r) AS src, endNode(r) AS dst, score AS similarity + ORDER BY similarity DESC + """, + q_text=self.lucene_sanitize(self.normalize_name(query_text)), # see normalization helper above + limit=limit, + prelimit=(limit * 5) if limit is not None else 100, + required_properties={ + Neo4jVectorGraphStore._sanitize_name(key): value + for key, value in required_properties.items() + }, + include_missing_properties=include_missing_properties, + timeout=3600, + ), + ) + for relation in query_typed_relations + ] + + results = await asyncio.gather(*fulltext_search_tasks) + + fulltext_search_edges = [] + for records, _, _ in results: + rels = [record["r"] for record in records] + fulltext_search_edges.extend( + Neo4jVectorGraphStore._edges_from_neo4j_relationships(rels) + ) + neo4j_nodes = [record["src"] for record in records] + src_nodes = Neo4jVectorGraphStore._nodes_from_neo4j_nodes(neo4j_nodes) + edge_source_node_map.update({ + n.uuid: n for n in src_nodes + }) + + fused = self.rrf([vector_search_edges, fulltext_search_edges], k=50) + result_edges = [edge for edge, _ in fused][:limit] + return result_edges, [edge_source_node_map[edge.source_uuid] for edge in result_edges] + + # async def hybrid_search_edges( + # self, + # query_text: str, + # query_embedding: list[float], + # similarity_threshold: float = 0.6, + # limit: int | None = 100, + # required_properties: dict[str, Property] = {}, + # include_missing_properties: bool = False, + # ) -> list[Edge]: + # async with self._semaphore: + # records, _, _ = await self._driver.execute_query( + # f""" + # CALL db.index.fulltext.queryRelationships('rel_tripletext_fts', $q_text) + # YIELD relationship AS r, score AS bm25 + # WITH r, bm25, + # CASE WHEN r.embedding IS NOT NULL + # THEN vector.similarity.cosine(r.embedding, $query_embedding) + # ELSE 0.0 + # END AS cos + # WHERE cos > $similarity_threshold + # MATCH p = ()-[r:RELATED_TO]-() + # WITH p, bm25, cos, 0.6*cos + 0.4*bm25 AS hybrid + # RETURN p + # ORDER BY hybrid DESC + # {'LIMIT $limit' if limit is not None else ''} + # """, + # q_text=self.lucene_sanitize(self.normalize_name(query_text)), # see normalization helper above + # query_embedding=query_embedding, + # similarity_threshold=similarity_threshold, # e.g., 0.25 + # limit=limit, + # ) + + # neo4j_paths = [record["p"] for record in records] + # return Neo4jVectorGraphStore._edges_from_neo4j_relationships([p.relationships[0] for p in neo4j_paths]) + + async def hybrid_search_nodes( + self, + node_name: str, + rrf_weights: list[float] | None, + limit: int = 3, + required_label: str = "Entity", + required_properties: dict[str, Property] = {}, + include_missing_properties: bool = False, + ) -> list[Node]: + fulltext_nodes = [] + substr_nodes = [] + + async with self._semaphore: + # Full text search + records, _, _ = await self._driver.execute_query( + f""" + CALL {{ + CALL db.index.fulltext.queryNodes('node_name_fts', $name, {{limit: $prelimit}}) + YIELD node AS n, score + WHERE { + Neo4jVectorGraphStore._format_required_properties( + "n", required_properties, include_missing_properties + ) + }\n + RETURN n, score + ORDER BY score DESC + LIMIT $limit + }} + RETURN n + ORDER BY score DESC + """, + name=self.lucene_sanitize(self.normalize_name(node_name)), + limit=limit, + prelimit=(limit * 5) if limit is not None else 100, + labels=[required_label], + required_properties={ + Neo4jVectorGraphStore._sanitize_name(key): value + for key, value in required_properties.items() + }, + include_missing_properties=include_missing_properties, + timeout=3600, + ) + + neo4j_nodes = [record["n"] for record in records] + fulltext_nodes = Neo4jVectorGraphStore._nodes_from_neo4j_nodes(neo4j_nodes) + + # TODO: This is slow, consider using vector search here instead + substr_nodes = [] + # if rrf_weights[1] != 0: + # async with self._semaphore: + # # Substring search + # records, _, _ = await self._driver.execute_query( + # f""" + # MATCH (n:{required_label}) + # WHERE n.name IS NOT NULL + # AND toLower(n.name) CONTAINS toLower($name) + # AND { + # Neo4jVectorGraphStore._format_required_properties( + # "n", required_properties, include_missing_properties + # ) + # }\n + # RETURN n + # {'LIMIT $limit' if limit is not None else ''} + # """, + # name=node_name, + # limit=limit, + # required_properties={ + # Neo4jVectorGraphStore._sanitize_name(key): value + # for key, value in required_properties.items() + # }, + # include_missing_properties=include_missing_properties, + # timeout=3600, + # ) + # neo4j_nodes = [record["n"] for record in records] + # substr_nodes = Neo4jVectorGraphStore._nodes_from_neo4j_nodes(neo4j_nodes) + + node_map = {} + for node in fulltext_nodes + substr_nodes: + if node.uuid not in node_map: + node_map[node.uuid] = node + ft_nodes_uuids = [n.uuid for n in fulltext_nodes] + ss_nodes_uuids = [n.uuid for n in substr_nodes] + + fused = self.rrf([ft_nodes_uuids, ss_nodes_uuids], weights=rrf_weights, k=50) + return [node_map[uuid] for uuid, _ in fused][:limit] + async def delete_nodes( self, node_uuids: Collection[UUID], @@ -514,6 +1019,97 @@ async def _create_node_vector_index_if_not_exist( await self._execute_create_node_vector_index_if_not_exist(create_index_tasks) self._vector_index_name_cache.update(requested_vector_index_names) + + async def _create_edge_vector_index_if_not_exist( + self, + relations: Collection[str], + embedding_property_name: str, + dimensions: int, + similarity_metric: SimilarityMetric = SimilarityMetric.COSINE, + ): + """ + Create node vector index(es) if not exist. + + Args: + labels (Collection[str]): + Collection of node labels to create vector indexes for. + embedding_property_name (str): + Name of the embedding property. + dimensions (int): + Dimensionality of the embedding vectors. + similarity_metric (SimilarityMetric): + Similarity metric to use for the vector index + (default: SimilarityMetric.COSINE). + """ + if not self._vector_index_name_cache: + async with self._semaphore: + records, _, _ = await self._driver.execute_query( + "SHOW VECTOR INDEXES YIELD name RETURN name" + ) + + self._vector_index_name_cache.update(record["name"] for record in records) + + sanitized_relations = [ + Neo4jVectorGraphStore._sanitize_name(relation) for relation in relations + ] + + sanitized_embedding_property_name = Neo4jVectorGraphStore._sanitize_name( + embedding_property_name + ) + + requested_vector_index_names = [ + Neo4jVectorGraphStore._edge_vector_index_name( + sanitized_relation, sanitized_embedding_property_name + ) + for sanitized_relation in sanitized_relations + ] + + info_for_vector_indexes_to_create = [ + (sanitized_relation, sanitized_embedding_property_name, vector_index_name) + for sanitized_relation, vector_index_name in zip( + sanitized_relations, + requested_vector_index_names, + ) + if vector_index_name not in self._vector_index_name_cache + ] + + if len(info_for_vector_indexes_to_create) == 0: + return + + match similarity_metric: + case SimilarityMetric.COSINE: + similarity_function = "cosine" + case SimilarityMetric.EUCLIDEAN: + similarity_function = "euclidean" + case _: + similarity_function = "cosine" + + create_index_tasks = [ + async_with( + self._semaphore, + self._driver.execute_query( + f"CREATE VECTOR INDEX {vector_index_name}\n" + "IF NOT EXISTS\n" + f"FOR ()-[r:{sanitized_relation}]-()\n" + f"ON r.{sanitized_embedding_property_name}\n" + "OPTIONS {\n" + " indexConfig: {\n" + " `vector.dimensions`:\n" + " $dimensions,\n" + " `vector.similarity_function`:\n" + " $similarity_function\n" + " }\n" + "}", + dimensions=dimensions, + similarity_function=similarity_function, + ), + ) + for sanitized_relation, sanitized_embedding_property_name, vector_index_name in info_for_vector_indexes_to_create + ] + + await self._execute_create_node_vector_index_if_not_exist(create_index_tasks) + + self._vector_index_name_cache.update(requested_vector_index_names) @async_locked async def _execute_create_node_vector_index_if_not_exist( @@ -682,6 +1278,66 @@ def _nodes_from_neo4j_nodes( for neo4j_node in neo4j_nodes ] + @staticmethod + def _edge_vector_index_name( + sanitized_label: str, sanitized_embedding_property_name: str + ) -> str: + """ + Generate a unique name for a node vector index + based on the label and embedding property name. + + Args: + sanitized_label (str): + The sanitized node label. + embedding_property_name (str): + The sanitized embedding property name. + + Returns: + str: The generated vector index name. + """ + return ( + "edge_vector_index" + "_for_" + f"{len(sanitized_label)}_" + f"{sanitized_label}" + "_on_" + f"{len(sanitized_embedding_property_name)}_" + f"{sanitized_embedding_property_name}" + ) + + @staticmethod + def _edges_from_neo4j_relationships( + neo4j_relationships: list[Neo4jRelationship] + ) -> list[Edge]: + """ + Convert a list of Neo4j Relationships to a list of Edges. + + Args: + neo4j_relationships (list[Neo4jRelationship]): + List of Neo4j Relationships. + + Returns: + list[Edge]: List of Edges. + """ + return [ + Edge( + uuid=UUID(neo4j_relationship["uuid"]), + source_uuid=UUID(neo4j_relationship.start_node["uuid"]), + target_uuid=UUID(neo4j_relationship.end_node["uuid"]), + relation=neo4j_relationship.type, + properties={ + Neo4jVectorGraphStore._desanitize_name( + key + ): Neo4jVectorGraphStore._python_value_from_neo4j_value( + value + ) + for key, value in neo4j_relationship.items() + if key != "uuid" + }, + ) + for neo4j_relationship in neo4j_relationships + ] + @staticmethod def _python_value_from_neo4j_value(value: Any) -> Any: """ diff --git a/src/memmachine/episodic_memory/episodic_memory.py b/src/memmachine/episodic_memory/episodic_memory.py index 02cc3b20c..dfdb01b88 100644 --- a/src/memmachine/episodic_memory/episodic_memory.py +++ b/src/memmachine/episodic_memory/episodic_memory.py @@ -19,9 +19,13 @@ import asyncio import copy import logging -import uuid +from uuid import UUID from datetime import datetime from typing import cast +import os +from rank_bm25 import BM25Okapi +import re +import time from memmachine.common.language_model.language_model_builder import ( LanguageModelBuilder, @@ -31,11 +35,32 @@ ) from .data_types import ContentType, Episode, MemoryContext +from memmachine.common.vector_graph_store import Node from .long_term_memory.long_term_memory import LongTermMemory from .short_term_memory.session_memory import SessionMemory -logger = logging.getLogger(__name__) +os.environ["DSPY_CACHEDIR"] = "/tmp/dspy_cache" +import dspy +logger = logging.getLogger(__name__) +dsyp_init = False + +import nltk + +print("Checking for required NLTK data...") +packages = [ + ("tokenizers/punkt", "punkt"), + ("tokenizers/punkt_tab", "punkt_tab"), + ("corpora/stopwords", "stopwords"), +] +for path, pkg_id in packages: + try: + nltk.data.find(path) + print(f"✅ - NLTK package '{pkg_id}' is already installed.") + except LookupError: + print(f"⚠️ - NLTK package '{pkg_id}' not found. Downloading...") + nltk.download(pkg_id) +print("\nNLTK data setup is complete. ✨") class EpisodicMemory: # pylint: disable=too-many-instance-attributes @@ -131,6 +156,12 @@ def __init__(self, manager, config: dict, memory_context: MemoryContext): "query_count", "Count of query processing" ) + global dsyp_init + if not dsyp_init: + # Check if DSPy is available + dspy.configure_cache(enable_disk_cache=True, enable_memory_cache=True, disk_cache_dir="/tmp/dspy_cache") + print("✅ DSPy available") + @property def short_term_memory(self) -> SessionMemory | None: """ @@ -201,6 +232,15 @@ async def reference(self) -> bool: return False self._ref_count += 1 return True + + def group_id(self) -> str: + """ + Get the group ID of the memory context. + + Returns: + The group ID as a string. + """ + return self._memory_context.group_id async def add_memory_episode( self, @@ -211,6 +251,7 @@ async def add_memory_episode( content_type: ContentType, timestamp: datetime | None = None, metadata: dict | None = None, + uuid: UUID | None = None, ): # pylint: disable=too-many-arguments # pylint: disable=too-many-positional-arguments @@ -257,8 +298,10 @@ async def add_memory_episode( start_time = datetime.now() # Create a new Episode object + if uuid is None: + raise ValueError("UUID must be provided for the episode") episode = Episode( - uuid=uuid.uuid4(), + uuid=uuid, episode_type=episode_type, content_type=content_type, content=episode_content, @@ -270,12 +313,24 @@ async def add_memory_episode( user_metadata=metadata, ) + # kg_episode = Node( + # uuid=cur_uuid, + # labels={"Episode"}, + # # Make timestamp different for each episode + # properties={ + # "content": episode_content, + # "timestamp": timestamp if timestamp else datetime.now(), + # "session_id": self._memory_context.group_id + # }, + # ) + # Add the episode to both memory stores concurrently tasks = [] if self._session_memory: tasks.append(self._session_memory.add_episode(episode)) if self._long_term_memory: tasks.append(self._long_term_memory.add_episode(episode)) + # tasks.append(add_episode_bulk([kg_episode])) await asyncio.gather( *tasks, ) @@ -361,50 +416,124 @@ async def query_memory( # By default, always allow cross session search property_filter["group_id"] = self._memory_context.group_id - async with self._lock: - if self._session_memory is None: - short_episode: list[Episode] = [] - short_summary = "" - long_episode = await cast( - LongTermMemory, self._long_term_memory - ).search( - query, - search_limit, - property_filter, - ) - elif self._long_term_memory is None: - session_result = await self._session_memory.get_session_memory_context( + kg_episodes = [] + short_episode= [] + short_summary = "" + long_episode = [] + # async with self._lock: + if self._session_memory is None: + short_episode: list[Episode] = [] + short_summary = "" + long_episode = await cast( + LongTermMemory, self._long_term_memory + ).search( + query, + search_limit, + property_filter, + ) + # _, kg_episodes = await search_kg(query, self._memory_context.group_id, limit=500) + elif self._long_term_memory is None: + session_result = await asyncio.gather( + self._session_memory.get_session_memory_context( + query, limit=search_limit + ), + ) + long_episode = [] + short_episode, short_summary = session_result + else: + # Concurrently search both memory stores + session_result, long_episode = await asyncio.gather( + self._session_memory.get_session_memory_context( query, limit=search_limit - ) - long_episode = [] - short_episode, short_summary = session_result - else: - # Concurrently search both memory stores - session_result, long_episode = await asyncio.gather( - self._session_memory.get_session_memory_context( - query, limit=search_limit - ), - self._long_term_memory.search(query, search_limit, property_filter), - ) - short_episode, short_summary = session_result + ), + self._long_term_memory.search(query, search_limit, property_filter), + ) + short_episode, short_summary = session_result + + # _, kg_episodes = await search_kg(query, self._memory_context.group_id, limit=500) + # print(f"KG Search got {len(kg_episodes)} KG episodes.") # Deduplicate episodes from both memory stores, prioritizing # short-term memory + unique_long_episodes = [] uuid_set = {episode.uuid for episode in short_episode} - unique_long_episodes = [] for episode in long_episode: if episode.uuid not in uuid_set: uuid_set.add(episode.uuid) unique_long_episodes.append(episode) + # def tokenize(s): # simple tokenizer; adapt as needed + # return re.findall(r"[A-Za-z0-9_'-]+", s.lower()) + + # def split_sentences(text): + # # lightweight splitter; replace with nltk/sentencepiece/spacy if you like + # return [s.strip() for s in re.split(r'(?<=[\.!?])\s+', text) if s.strip()] + + # def top_k_support_bm25(episodes, query, k=5): + # if len(episodes) == 0: + # return [] + + # sents = [] + # uuids = [] + # uuid_to_episode = {} + # for e in episodes: + # for line in e.properties["content"].split("\n"): + # cur_sents = split_sentences(line) + # for s in cur_sents: + # if s == "": + # continue + # if tokenize(s) == []: + # continue + # sents.append(s) + # uuids.append(e.uuid) + # uuid_to_episode[e.uuid] = e + + # # sents = split_sentences(context) + # corpus = [tokenize(s) for s in sents] + # bm25 = BM25Okapi(corpus) + # scores = bm25.get_scores(tokenize(query)) + # top_idx = sorted(range(len(sents)), key=lambda i: scores[i], reverse=True)[:k] + # return [uuid_to_episode[uuids[i]] for i in top_idx] + + result_kg_episodes = [] + + # ts = time.perf_counter() + # top_kg_episodes = top_k_support_bm25(kg_episodes, query, k=search_limit*2) + # print(f"BM25 ranking took {time.perf_counter() - ts:.4f} seconds.") + + # Filter duplicate kg_episodes + # num_dup = 0 + # for e in top_kg_episodes: + # if e.uuid in uuid_set: + # num_dup += 1 + # continue + # uuid_set.add(e.uuid) + # result_kg_episodes.append(e) + + num_dup = 0 + for e in kg_episodes: + if e.uuid in uuid_set: + num_dup += 1 + continue + uuid_set.add(e.uuid) + result_kg_episodes.append(e) + + # print(f"Filtered {num_dup} duplicate KG episodes.") + # result_kg_episodes = result_kg_episodes[-search_limit:] + # result_kg_episodes = sorted( + # result_kg_episodes, + # key=lambda e: (e.properties.get('timestamp') is None, + # e.properties.get('timestamp')) + # ) + end_time = datetime.now() delta = end_time - start_time self._query_latency_summary.observe( delta.total_seconds() * 1000 + delta.microseconds / 1000 ) self._query_counter.increment() - return short_episode, unique_long_episodes, [short_summary] + return short_episode, unique_long_episodes, [short_summary], result_kg_episodes async def formalize_query_with_context( self, diff --git a/src/memmachine/episodic_memory/episodic_memory_manager.py b/src/memmachine/episodic_memory/episodic_memory_manager.py index de9adb35e..d3ab1f2ea 100644 --- a/src/memmachine/episodic_memory/episodic_memory_manager.py +++ b/src/memmachine/episodic_memory/episodic_memory_manager.py @@ -396,36 +396,36 @@ async def get_episodic_memory_instance( session_id=session_id, ) - async with self._lock: - # If an instance for this context already exists, increment its - # reference count and return it. - if context in self._context_memory: - instance = self._context_memory[context] - get_it = await instance.reference() - if get_it: - return instance - # The instance was closed between checking and referencing. - logger.error("Failed get instance reference") - return None - # If no instance exists, create a new one. - try: - info = self._session_manager.open_session(group_id, session_id) - final_config = info.configuration - except ValueError: - if configuration is None: - configuration = {} - final_config = self._merge_configs(self._memory_config, configuration) - info = self._session_manager.create_session_if_not_exist( - group_id, agent_id, user_id, session_id, final_config - ) - - # Create and store the new memory instance. - memory_instance = EpisodicMemory(self, final_config, context) + # async with self._lock: + # If an instance for this context already exists, increment its + # reference count and return it. + if context in self._context_memory: + instance = self._context_memory[context] + get_it = await instance.reference() + if get_it: + return instance + # The instance was closed between checking and referencing. + logger.error("Failed get instance reference") + return None + # If no instance exists, create a new one. + try: + info = self._session_manager.open_session(group_id, session_id) + final_config = info.configuration + except ValueError: + if configuration is None: + configuration = {} + final_config = self._merge_configs(self._memory_config, configuration) + info = self._session_manager.create_session_if_not_exist( + group_id, agent_id, user_id, session_id, final_config + ) - self._context_memory[context] = memory_instance + # Create and store the new memory instance. + memory_instance = EpisodicMemory(self, final_config, context) - await memory_instance.reference() - return memory_instance + self._context_memory[context] = memory_instance + + await memory_instance.reference() + return memory_instance async def delete_context_memory(self, context: MemoryContext): """ diff --git a/src/memmachine/knowledge_graph/knowledge_graph.py b/src/memmachine/knowledge_graph/knowledge_graph.py new file mode 100644 index 000000000..44be9a1ad --- /dev/null +++ b/src/memmachine/knowledge_graph/knowledge_graph.py @@ -0,0 +1,1371 @@ +import os +import json +import re +import time, random +from datetime import datetime, timezone, timedelta +from collections import defaultdict, deque +from collections.abc import Collection +from typing import List, Tuple, Dict, Any, Iterable, DefaultDict +from rank_bm25 import BM25Okapi +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize, sent_tokenize +import boto3 + +import asyncio +from uuid import uuid4, UUID + +from dotenv import load_dotenv +from neo4j import AsyncGraphDatabase +from openai import AsyncOpenAI + +from memmachine.common.utils import async_locked +from memmachine.common.vector_graph_store import Node, Edge +from memmachine.common.vector_graph_store.neo4j_vector_graph_store import Neo4jVectorGraphStore, Neo4jVectorGraphStoreParams +from memmachine.common.embedder.openai_embedder import OpenAIEmbedder +from memmachine.common.language_model.openai_language_model import OpenAILanguageModel +from memmachine.common.reranker.amazon_bedrock_reranker import ( + AmazonBedrockReranker, + AmazonBedrockRerankerParams, +) + +os.environ["DSPY_CACHEDIR"] = "/tmp/dspy_cache" +import dspy + +EPISODE_GROUP_PROMPT = """You are an expert at grouping related episodes by **content similarity**. + +**Goal**: +Analyze the episodes and cluster those that are closely related in content, themes, or topics. Avoid speculative links—use only what's explicitly present. + +**Episodes**: +{episodes} + +**Rules** (be strict and deterministic): + - Treat two episodes as related only if they share a central topic (same event/series, person/organization, product/model, place, study, bug/feature, or storyline). + - Ignore superficial overlap (generic words like “update,” “today,” “news,” dates, or common verbs). + - Prefer **precision over recall**: if unsure, keep the episode as a singleton. + - **Transitivity**: If A is related to B and B to C, put A, B, C in the same group. + - **No overlaps**: Each episode index must appear exactly once in exactly one group. + +**Method** (concise): + 1. **Extract key features** per episode: named entities (people/orgs/places), key nouns/phrases, and specific identifiers (e.g., model numbers, issue IDs). + 2. Score pairwise similarity on a 0–10 scale: + - **10** = near-duplicate + - **8-9** = strongly about the same specific topic + - **6-7** = related or semantically similar (weaker tie) + - **0-5** = unrelated + 3. **Create edges** between episode pairs using this deterministic rule: + - Add an edge if score ≥ 8; or + - If score 6-7 and the pair shares at least one anchor: an exact unique identifier (e.g., “ISSUE-1234”, “M3 MacBook Pro”), or a combined anchor of (same named entity AND same event/date/place). + 4. Take the **transitive closure** (connected components) over edges to form groups. If still ambiguous, keep as a singleton. + 5. Break ties by dominant topic; if still ambiguous, use a singleton. +**Output Format** (strict JSON, no additional text before or after the JSON block): +{{ + "groups": [[0,1], [2,3,4] ...] +}} + +**Examples**: +Example 1: +Episodes: + [0][2025-05-12T09:14:03] Mia: “Should we do Yosemite for Memorial Day? We'll need a backcountry permit.” + [1][2025-05-12T09:40:51] Raj: “I'll submit the permit lottery today—deadline is Friday.” + [2][2025-05-15T17:40:19] Raj: “We lost the lottery.” + [3][2025-05-15T17:41:03] Mia: “Let's try Tuolumne first-come or shift dates.” + [4][2025-05-20T07:52:44] Raj: “If we want Half Dome, we should train for the cables and rent a bear can.” + [5][2025-05-21T19:12:05] Mia: “The espresso machine keeps beeping—descale light won't clear.” + [6][2025-05-21T19:13:48] Raj: “Hold both cup buttons for 5 seconds to reset.” + [7][2025-05-16T10:03:11] Mia: “My React app won't compile—Vite can't find Tailwind classes.” + [8][2025-05-16T11:25:33] Raj: “Add the Tailwind plugin in postcss; rebuild should fix it.” +Output: +{{ + "groups": [[0,1,2,3,4],[5,6],[7,8]] +}} + +Now generate the episode groups based on the provided episodes: +""" + +EPISODE_GROUP_SUMMARY_PROMPT ="""You are a precise dialogue summarizer. Write one concise sentence that captures the main topic and current status of the episodes below. All episodes are about the same topic. + +**Episodes**: +{episodes} + +**Instructions**: + - Use only facts explicitly stated in the episodes; no speculation. + - Preserve exact names/terms (places, products, IDs). + - Focus on the overall goal, key events/changes, and present status/next steps. + - Length: ≤ 30 words. Form: exactly one sentence. + +**Output format**: +Return only the sentence—no quotes, no labels, no extra text. + +**Example**: +Episodes: +[2025-05-12T09:14:03] Mia: “Should we do Yosemite for Memorial Day? We'll need a backcountry permit.” +[2025-05-12T09:40:51] Raj: “I'll submit the permit lottery today—deadline is Friday.” +[2025-05-15T17:40:19] Raj: “We lost the lottery.” +[2025-05-15T17:41:03] Mia: “Let's try Tuolumne first-come or shift dates.” +[2025-05-20T07:52:44] Raj: “If we want Half Dome, we should train for the cables and rent a bear can.” +Output (single string): +Yosemite trip planning: permit lottery submitted then lost; considering Tuolumne or new dates, with training and a bear can for a potential Half Dome attempt. + +Now generate the episode summary based on the provided episodes: +""" + +SUFFICIENCY_CHECK_PROMPT = """You are a meticulous and detail-oriented expert in information retrieval evaluation. Your task is to critically assess whether a set of retrieved documents contains sufficient information to provide a direct and complete answer to a user's query. + +**User Query**: +{query} + +**Retrieved Documents**: +{retrieved_episodes} + +**Instructions**: +Follow these steps to ensure an accurate evaluation: + +1. **Deconstruct the Query**: Break down the user's query into its core informational needs (e.g., who, what, where, when, why, how). Identify all key entities and concepts. +2. **Scan for Keywords**: Quickly scan the documents for the key entities and concepts from the query. This is a preliminary check for relevance. +3. **Detailed Analysis**: Read the relevant parts of the documents carefully. Determine if they contain explicit facts that directly answer *all aspects* of the query. Do not rely on making significant inferences or assumptions. The answer should be explicitly stated or very easily pieced together from the text. +4. **Sufficiency Judgment**: + * If the query asked for specific details (names, dates, locations, numbers), and no exact details are provided, label as **insufficient**. + * If all parts of the query are directly and explicitly answered, the documents are **sufficient**. + * If any significant part of the query is not answered, the documents are **insufficient**. + * If you are uncertain, err on the side of caution and label it as **insufficient**. Do not guess. + * If query ask for frequencies, lists, such as "how many", "how often", the return should always be **insufficient** since you don't have complete data. +5. **Formulate Reasoning**: Based on your analysis, write a brief (1-2 sentences) explanation for your judgment. +6. **Identify Direct Evidence**: List the indices (0-based) of the documents that are clearly needed to answering the query. Ignore documents that are unrelated or tangential. + +**Output Format** (strict JSON, no additional text before or after the JSON block): +{{ + "is_sufficient": true or false, + "reasoning": "Brief, clear explanation for your decision (1-2 sentences).", + "indices": [index1, index2, ...] +}} + +**Examples**: + +Example 1 (Sufficient): +Query: "What's Alice's hobbies?" +Documents: + [0][2022-01-23T14:01:35] "Alice mentioned she loves painting. She spends weekends at art galleries." + [1][2022-01-23T14:05:40] "Alice's work involves creative projects. She loves to travel off work." +Output: +{{ + "is_sufficient": true, + "reasoning": "Document 0 and 1 explicitly states that Alice's hobbies are painting and traveling.", + "indices": [0, 1] +}} + +Example 2 (Insufficient - Missing Detail): +Query: "Where did Bob go on vacation last year?" +Documents: + [0][2024-01-23T14:01:35] "Bob talked about his work project deadline around June." + [1][2024-01-23T14:05:40] "Bob likes to travel and explore new places. He recently came back from a trip." +Output: +{{ + "is_sufficient": false, + "reasoning": "The documents mention Bob likes to travel and recently took a trip, but do not specify where he went or if it was last year.", + "indices": [1] +}} + +Example 3 (Insufficient - Tangential Information): +Query: "What are the specs of the new 'Galaxy Z' phone?" +Documents: + [0][2024-01-23T14:01:35] "The Galaxy Z is rumored to be released next month." + [1][2024-02-23T14:01:35] "Tech enthusiasts are excited about the upcoming Galaxy Z launch event." +Output: +{{ + "is_sufficient": false, + "reasoning": "The documents mention the phone's upcoming release but provide no specific technical specifications.", + "indices": [] +}} + +Now evaluate: +""" + +def rrf( + lists: list[list[Any]], + weights: list[float] | None = None, + k: int = 50, +) -> list[tuple[Any, float]]: + """ + Reciprocal Rank Fusion (RRF) implementation. + Args: + lists (list[list[str]]): List of ranked lists to fuse. + weights (list[float] | None): Weights for each ranked list. If None, equal weights are used. + k (int): Constant to control the influence of rank. Smaller k gives more weight to higher ranks. + Returns: + list[tuple[str, float]]: List of tuples containing item and its fused score, sorted by score in descending order. + """ + if not lists: + return [] + + n = len(lists) + if weights is None: + weights = [1.0] * n + elif len(weights) != n: + raise ValueError(f"RRF: weights length {len(weights)} must equal number of lists {n}") + + scores: dict[str, float] = defaultdict(float) + + for w, L in zip(weights, lists): + for rank, _id in enumerate(L, start=1): + if w > 0: + scores[_id] += w * (1.0 / (k + rank)) + + return sorted(scores.items(), key=lambda x: x[1], reverse=True) + +class KnowledgeGraph: + """ + Knowledge Graph object should be instantiated once per session_id. + """ + def __init__( + self, + model_name: str, + model: AsyncOpenAI, + embedder: OpenAIEmbedder, + store: Neo4jVectorGraphStore, + reranker: AmazonBedrockReranker = None, + ): + # Resources + self._embedder = embedder + self._store = store + self._model = model + self._model_name = model_name + self._reranker = reranker + + # Local caches + self._entity_node_map = {} + self._entity_edge_map = {} + self._episode_cluster_edge_map = {} + self._episode_batch: list[Node] = [] + self._processed_triple_texts = set() + + # Constants + self._embedding_batch_size = 100 + + # Performance metrics + self._perf_episode_grouping_time = 0.0 + self._perf_episode_summary_time = 0.0 + self._perf_entity_search_time = 0.0 + self._perf_relation_extraction_time = 0.0 + self._perf_embedding_time = 0.0 + self._perf_node_creation_time = 0.0 + self._perf_edge_creation_time = 0.0 + self._entity_node_created = 0 + self._relation_edge_created = 0 + self._episode_cluster_node_created = 0 + + # Setup edge_extractor + self._edge_extractor = setup_dialogue_relation_extractor() + self._extract_sem = asyncio.Semaphore(200) + + def setup_dialogue_relation_extractor(self): + """Setup DSPy-optimized relation extractor for text with coreference resolution""" + + class EdgeExtraction(dspy.Signature): + """ + Extract semantic relationships from input text with advanced coreference resolution. Return as JSON array. + + **Guidelines**: + - Extract meaningful semantic relationships between entities in the text + - Resolve complex coreferences in multi-speaker conversations: + a) Identify all pronouns, definite articles, and vague references (the, this, that, it, they, etc.) + b) Trace each reference back to its most specific antecedent in the conversation + c) Replace vague references with the most specific, concrete entity name possible + d) For events/actions, create descriptive names that capture the essence + - Handle speaker-specific references: + a) Resolve possessive pronouns and contextual references to specific person names when possible + b) Identify relationships between speakers and mentioned entities + c) Track entity mentions across multiple speakers + - Pay special attention to temporal relationships and time information: + a) Extract when events occurred (dates, times, periods) + b) Preserve temporal context in relationships + c) Include time information in relation, do not include in subject/object + - Use context clues to determine the most appropriate entity name + - Do not emit duplicate or semantically redundant relationships + - Return as JSON array of objects with "subject", "relation", "object" fields + - DO NOT include any duplicate relationships + - Make sure the resulting JSON is valid and parsable + - No comments or '//' or explainations in the output JSON, only the JSON array + - The JSON array should not exceed 16000 characters + + **Coreference Resolution Strategy**: + - Replace any vague or ambiguous references with specific entity names based on conversation context + - For events and actions, create descriptive names that capture what actually happened + - For people and organizations, use their full, proper names when available + - Always preserve temporal information and context + - Handle speaker-specific references and possessive pronouns + + **Output Format**(strict JSON parsable, no additional text): + [{ + "subject": "Entity1", + "relation": "relation1", + "object": "Entity2" + }, + { + "subject": "Entity1", + "relation": "relation2", + "object": "Entity2" + }] + + Before returning the result, do the following step by step: + 1. Remove duplicated or semantically redundant relationships + 2. The resulting relationship dictionaries should have "subject", "relation", "object" fields, if not, fix them. + 2. Parse the resulting JSON to make sure it is valid and parsable. If not parsable, redo the extraction. + 3. Check there is no comments or '//' or explainations in the output JSON, only the JSON array + 4. The JSON array should not exceed 16000 characters, if too long, remove less important relationships. + """ + content: str = dspy.InputField(desc="Input text to extract relationships from") + relations: str = dspy.OutputField(desc="JSON array of relationships with subject, relation, object") + + # Create the edge predictor + edge_predictor = dspy.Predict(EdgeExtraction) + return edge_predictor + + async def extract_dialogue_relations( + self, + text: str + ) -> List[Tuple[str, str, str, List[int]]]: + """Extract relations from dialogue text using DSPy-optimized method with coreference resolution""" + retry = 2 + cache = True + while retry >= 0: + lm = dspy.LM( + model="gpt-4.1-mini", + api_key=os.getenv("OPENAI_API_KEY"), + temperature=0.2, + max_tokens=16000, + cache=cache, + ) + + # Extract relations + with dspy.settings.context(lm=lm): + async with self._extract_sem: + result = await self._edge_extractor.acall( + content=text, + ) + try: + for i in range(len(result.relations)): + if "//" in result.relations[i]: + print(f"WARNING: Found '//' in relation extraction result line: {result.relations[i]}") + # Remove comments in any line + index = result.relations[i].find("//") + result.relations[i] = result.relations[i][:index] + if result.relations[-1] != "]": + # Remove from last "}," and append "]" + index = result.relations.rfind("},") + result.relations = result.relations[:index + 1] + "]" + print(f"WARNING: Fixed truncated JSON array in RE from Text:\n{text}\nResult:\n{result.relations}") + + # Parse JSON response + relations = json.loads(result.relations) + triples = [] + seen = set() + for rel in relations: + if all(key in rel for key in ['subject', 'relation', 'object']): + triple_text = f"{rel['subject']} {rel['relation']} {rel['object']}" + # Dedup result before returning + if triple_text in seen: + continue + triples.append((rel['subject'], rel['relation'], rel['object'])) + seen.add(triple_text) + return triples + except Exception as e: + print(f"Warning: Failed to parse relations from text:({text}), relation result({result.relations}): {str(e)}") + if retry == 0: + raise e + retry -= 1 + cache = False + + + async def get_episode_groups( + self, + episodes: List[Node], + ) -> tuple[list[list[Node]], list[Node]]: + """Group episodes into clusters based on content similarity""" + episode_content_str = "" + for i, e in enumerate(episodes): + episode_content_str += f"[{i}][{e.properties['timestamp']}] {e.properties['content']}\n" + episode_group_prompt = EPISODE_GROUP_PROMPT.format(episodes=episode_content_str) + res = await self._model.responses.create( + model=self._model_name, + max_output_tokens=4096, + temperature=0.0, + input=[{"role": "user", "content": episode_group_prompt}], + ) + try: + json_parsable_str = "" + for line in res.output_text.split("\n"): + if line == "```json" or line == "```": + continue + json_parsable_str += line + "\n" + response = json.loads(json_parsable_str) + except Exception as e: + print(f"WARNING: Failed to parse episode grouping response JSON: {json_parsable_str}, error: {e}") + raise e + + # If last two indices forms a group or forms singleton, do not group them and pass to next grouping + filtered_groups = [] + left_over_episodes = [] + for group in response.get("groups", []): + add = True + last_idx = len(episodes) - 1 + second_last_idx = len(episodes) - 2 + if len(group) == 1: + if last_idx in group or second_last_idx in group: + add = False + elif len(group) == 2: + if last_idx in group and second_last_idx in group: + add = False + if add: + filtered_groups.append(group) + else: + for idx in group: + left_over_episodes.append(episodes[idx]) + + groups = [] + for group in filtered_groups: + cur_group = [] + for idx in group: + if idx < 0 or idx >= len(episodes): + print(f"WARNING: episode grouping returned invalid episode index: {idx}, len(episodes)={len(episodes)}") + continue + cur_group.append(episodes[idx]) + groups.append(cur_group) + + return groups, left_over_episodes + + async def get_episode_group_summary( + self, + episodes: List[Node], + ) -> str: + """Generate a concise summary for a group of related episodes""" + episode_content_str = "" + for e in episodes: + episode_content_str += f"[{e.properties['timestamp']}] {e.properties['content']}\n" + episode_summary_prompt = EPISODE_GROUP_SUMMARY_PROMPT.format(episodes=episode_content_str) + res = await self._model.responses.create( + model=self._model_name, + max_output_tokens=512, + temperature=0.0, + input=[{"role": "user", "content": episode_summary_prompt}], + ) + summary = res.output_text.strip() + return summary + + async def get_related_episode_cluster( + self, + uuid: UUID, + session_id: str, + index_search_label: str, + ) -> set[Node]: + return await self._store.search_related_nodes( + node_uuid=uuid, + allowed_relations={"INCLUDE"}, + find_sources=True, + find_targets=False, + limit=None, + required_labels={"EpisodeCluster"}, + required_properties={"session_id": session_id}, + ) + + async def batch_ingest_embed( + self, + content: list[str], + ) -> dict[str, list[float]]: + # Dedup and remove empty strings + tmp_set = set() + dedup_content = [] + for c in content: + if c == "": + continue + if c in tmp_set: + continue + dedup_content.append(c) + tmp_set.add(c) + + num_batch = self._embedding_batch_size + if len(dedup_content) < num_batch: + num_batch = len(dedup_content) if len(dedup_content) > 0 else 1 + embeddings = [] + for j in range(0, len(dedup_content), num_batch): + if j + num_batch > len(dedup_content): + num_batch = len(dedup_content) - j + batch = [c for c in dedup_content[j:j+num_batch]] + batch_embeddings = await self._embedder.ingest_embed(batch, max_attempts=3) + embeddings.extend(batch_embeddings) + return { + c: emb for c, emb in zip(dedup_content, embeddings) + } + + async def generate_edges_nodes( + self, + episode_cluster: Node, + source: str, + relation: str, + target: str, + triple_text: str, + source_embedding: list[float], + target_embedding: list[float], + triple_text_embedding: list[float], + ) -> tuple[list[Node], list[Edge]]: + nodes: list[Node] = [] + edges: list[Edge] = [] + episode_cluster_edges: list[Edge] = [] + t = time.perf_counter() + # 1. Search for existing source and target entity nodes + # TODO: currently assume source/target entity in same batch of episodes are the same + if source not in self._entity_node_map: + # TODO: Optimize by batch processing with asyncio.gather + search_node = await self._store.search_similar_nodes( + query_embedding=source_embedding, + embedding_property_name="name_embedding", + limit=2, + required_labels={"Entity"}, + required_properties={"session_id": episode_cluster.properties["session_id"]}, + ) + for n in search_node: + if source == n.properties["name"]: + self._entity_node_map[source] = n.uuid + break + + if target != "" and target not in self._entity_node_map: + search_node = await self._store.search_similar_nodes( + query_embedding=target_embedding, + embedding_property_name="name_embedding", + limit=2, + required_labels={"Entity"}, + required_properties={"session_id": episode_cluster.properties["session_id"]}, + ) + for n in search_node: + if target == n.properties["name"]: + self._entity_node_map[target] = n.uuid + break + self._perf_entity_search_time += time.perf_counter() - t + + # 2. Create source and target entity nodes if not exist + if source not in self._entity_node_map: + node = Node( + uuid=uuid4(), + labels={"Entity"}, + properties={ + "name": source, + "session_id": episode_cluster.properties["session_id"], + "name_embedding": source_embedding, + }, + ) + nodes.append(node) + self._entity_node_map[source] = node.uuid + # print(f"Created source node: {source} with UUID {source_node.uuid}") + + if target != "" and target not in self._entity_node_map: + node = Node( + uuid=uuid4(), + labels={"Entity"}, + properties={ + "name": target, + "session_id": episode_cluster.properties["session_id"], + "name_embedding": target_embedding, + }, + ) + nodes.append(node) + self._entity_node_map[target] = node.uuid + # print(f"Created target node: {target} with UUID {target_node.uuid}") + + # 3. Create edge from source to target entity. This edge does not contain any data. + # This edge ensures related episode clusters can be connected via entity nodes. + edge_id = source + "-" + target + if target != "" and edge_id not in self._entity_edge_map: + edge = Edge( + uuid=uuid4(), + source_uuid=self._entity_node_map[source], + target_uuid=self._entity_node_map[target], + relation="RELATED_TO", + properties={ + "session_id": episode_cluster.properties["session_id"], + }, + ) + edges.append(edge) + self._entity_edge_map[edge_id] = edge.uuid + + # 4. Create edge from episode cluster to source entity + e_to_source_id = str(episode_cluster.uuid) + "-" + triple_text + if e_to_source_id not in self._episode_cluster_edge_map: + # Add edge from episode cluster to source node + edges.append( + Edge( + uuid=uuid4(), + source_uuid=episode_cluster.uuid, + target_uuid=self._entity_node_map[source], + relation="HAS_RELATION", + # No 'timestamp' field because we don't know which exact episode(s) the relation comes from + properties={ + "triple_text": triple_text, + "triple_text_embedding": triple_text_embedding, + "session_id": episode_cluster.properties["session_id"], + }, + ) + ) + self._episode_cluster_edge_map[e_to_source_id] = True + self._relation_edge_created += 1 + + return nodes, edges + + # Have to be locked to avoid duplicate edge/node because the local caches are not async safe. + @async_locked + async def add_episode_bulk(self, + episodes: List[Node], + flush: bool = False, + ) -> None: + if len(episodes) == 0: + return + + # All episodes must belong to the same session + session_id = episodes[0].properties["session_id"] + for e in episodes: + if e.properties["session_id"] != session_id: + raise ValueError(f"All episodes must belong to the same session for bult adding, found episode with session_id {e.properties['session_id']} instead of {session_id}") + + cluster_nodes: list[Node] = [] + cluster_edges: list[Edge] = [] + episode_groups = [] + + # 1. Chunking episodes into batches and generating episode clusters + e_queue = deque(episodes) + while len(e_queue) > 0: + # i). Fill episode_batch to 10 + if len(self._episode_batch) < 10: + self._episode_batch.append(e_queue.popleft()) + # Leftover episodes less than 10, wait for next bulk add if not flushing + if len(self._episode_batch) < 10: + if flush and len(e_queue) == 0: + # continue to process leftover episodes + pass + # Otherwise, continue to accumulate episodes until reached batch size + continue + + # ii). Use LLM to generate episode cluster that groups epsisodes + t = time.perf_counter() + groups, left_over_episodes = await self.get_episode_groups(self._episode_batch) + self._perf_episode_grouping_time += time.perf_counter() - t + + # Update episode_batch with leftover episodes + self._episode_batch = left_over_episodes + + # iii). Generate summary and create episode cluster nodes + for group in groups: + episode_groups.append(group) + t = time.perf_counter() + summary = await self.get_episode_group_summary(group) + self._perf_episode_summary_time += time.perf_counter() - t + cluster_node = Node( + uuid=uuid4(), + labels={"EpisodeCluster"}, + properties={ + "summary": summary, + "session_id": group[0].properties["session_id"], + }, + ) + cluster_nodes.append(cluster_node) + self._episode_cluster_node_created += 1 + # Create edges from cluster node to episodes + for episode in group: + edge = Edge( + uuid=uuid4(), + source_uuid=cluster_node.uuid, + target_uuid=episode.uuid, + relation="INCLUDES", + properties={ + "session_id": episode.properties["session_id"], + "timestamp": episode.properties["timestamp"], + }, + ) + cluster_edges.append(edge) + + if len(episode_groups) == 0: + return + + if len(episode_groups) != len(cluster_nodes): + raise ValueError(f"Episode groups length {len(episode_groups)} does not match cluster nodes length {len(cluster_nodes)}") + + # 2. Batch extracting relations from each group of episodes + relation_extraction_tasks = [] + for group in episode_groups: + group_content = "" + for e in group: + group_content += f"[{e.properties['timestamp']}] {e.properties['content']}\n" + relation_extraction_tasks.append(self.extract_dialogue_relations(group_content)) + + t = time.perf_counter() + relations_list = await asyncio.gather(*relation_extraction_tasks) + self._perf_relation_extraction_time += time.perf_counter() - t + + if len(relations_list) != len(cluster_nodes): + raise ValueError(f"Relations list length {len(relations_list)} does not match cluster nodes length {len(cluster_nodes)}") + + # 3. Get embeddings for triples, sources, targes, and episode cluster summarys + sources = [] + targets = [] + triple_texts = [] + cluster_summaries = [] + for relations in relations_list: + for source, relation, target in relations: + triple_text = f"{source} {relation} {target}" if target != "" else f"{source} {relation}" + sources.append(source) + targets.append(target) + triple_texts.append(triple_text) + + for n in cluster_nodes: + cluster_summaries.append(n.properties["summary"]) + + t = time.perf_counter() + sources_embedding = await self.batch_ingest_embed(sources) + targets_embedding = await self.batch_ingest_embed(targets) + triple_texts_embedding = await self.batch_ingest_embed(triple_texts) + cluster_summaries_embedding = await self.batch_ingest_embed(cluster_summaries) + self._perf_embedding_time += time.perf_counter() - t + + # 4. Generate entity edges and nodes, assign embeddings + nodes = [] + edges = [] + for relations, episode_cluster in zip(relations_list, cluster_nodes): + # Assign summary embedding to episode cluster + episode_cluster.properties["summary_embedding"] = cluster_summaries_embedding[episode_cluster.properties["summary"]] + for source, relation, target in relations: + triple_text = f"{source} {relation} {target}" if target != "" else f"{source} {relation}" + if triple_text in self._processed_triple_texts: + continue + self._processed_triple_texts.add(triple_text) + + n_res, e_res = await self.generate_edges_nodes( + episode_cluster=episode_cluster, + source=source, + relation=relation, + target=target, + triple_text=triple_text, + source_embedding=sources_embedding[source], + target_embedding=targets_embedding[target] if target != "" else [], + triple_text_embedding=triple_texts_embedding[triple_text], + ) + nodes.extend(n_res) + edges.extend(e_res) + + # 5. Bulk add all nodes and edges to graph store + t = time.perf_counter() + await self._store.add_nodes(episodes) + await self._store.add_nodes(cluster_nodes) + await self._store.add_nodes(nodes) + self._entity_node_created += len(nodes) + self._episode_cluster_node_created += len(cluster_nodes) + self._perf_node_creation_time = time.perf_counter() - t + + t = time.perf_counter() + await self._store.add_edges(cluster_edges) + await self._store.add_edges(edges) + self._perf_edge_creation_time = time.perf_counter() - t + + def print_ingest_perf_matrix(self): + print(f"Ingestion Performance Matrics:") + print(f" Episode Grouping Time: {self._perf_episode_grouping_time:.2f} seconds") + print(f" Episode Summary Time: {self._perf_episode_summary_time:.2f} seconds") + print(f" Entity Search Time: {self._perf_entity_search_time:.2f} seconds") + print(f" Relation Extraction Time: {self._perf_relation_extraction_time:.2f} seconds") + print(f" Embedding Time: {self._perf_embedding_time:.2f} seconds") + print(f" Node Creation Time: {self._perf_node_creation_time:.2f} seconds") + print(f" Edge Creation Time: {self._perf_edge_creation_time:.2f} seconds") + print(f" Entity Nodes Created: {self._entity_node_created}") + print(f" Relation Edges Created: {self._relation_edge_created}") + print(f" Episode Cluster Nodes Created: {self._episode_cluster_node_created}") + + async def cohere_rerank( + self, + items: list[Node] | list[Edge], + score_threshold: float, + query: str, + limit: int | None, + ) -> list[tuple[Node | Edge, float]]: + if len(items) == 0: + return [] + content_list = [] + if isinstance(items[0], Node): + for e in items: + if 'content' in e.properties: + content_list.append(e.properties['content']) + elif 'summary' in e.properties: + content_list.append(e.properties['summary']) + elif isinstance(items[0], Edge): + for e in items: + content_list.append(e.properties['triple_text']) + else: + raise Exception(f"Unknown item type for reranking: {type(items)}") + + num_max = 1000 + processed = 0 + scores = [] + while processed < len(content_list): + batch_contents = content_list[processed:processed+num_max] + success = False + max_retry = 60 + batch_scores = [] + while not success: + try: + batch_scores = await self._reranker.score(query, batch_contents) + success = True + except Exception as e: + max_retry -= 1 + if max_retry == 0: + print(f"ERROR: Reranker failed after maximum retries.") + raise e + if "ThrottlingException" in str(e): + print(f"WARNING: Reranker throttling exception, retrying after 60 second...") + time.sleep(60) + else: + raise e + scores.extend(batch_scores) + processed += len(batch_contents) + + + scored = sorted( + zip(items, scores), + key=lambda x: x[1], # sort by score + reverse=True # highest score first + ) + + result = [] + for e, s in scored: + if s < score_threshold and limit is not None and len(result) >= limit: + break + if limit is not None and len(result) >= limit: + break + result.append((e, s)) + + return result + + async def check_sufficiency( + self, + query: str, + episodes: list[Node], + ) -> tuple[dict[str, Any], list[Node], int, int]: + episode_content = "" + for idx, e in enumerate(episodes): + episode_content += f"[{idx}][{e.properties['timestamp']}] {e.properties['content']}\n" + + sufficient_check_prompt = SUFFICIENCY_CHECK_PROMPT.format( + query=query, + retrieved_episodes=episode_content, + ) + res = await self._model.responses.create( + model=self._model_name, + max_output_tokens=4096, + temperature=0.0, + # reasoning={"effort": "none"}, + input=[{"role": "user", "content": sufficient_check_prompt}], + ) + json_parsable_str = "" + try: + for line in res.output_text.split("\n"): + if line == "```json" or line == "```": + continue + if line.strip().startswith("\"reasoning\":"): + start_pos = line.find(":") + # Find first and last quote + first_quote = line.find("\"", start_pos) + last_quote = line.rfind("\"") + # For any extra quote in between without epsace, replace with single quote + # Find quotes in between + for i in range(first_quote+1, last_quote): + if line[i] == "\"": + # Check if escaped + if i > 0 and line[i-1] == "\\": + continue + # Add the escape char + line = line[:i] + "\\" + line[i:] + # Move i and last_quote forward by 1 + i += 1 + last_quote += 1 + json_parsable_str += line + "\n" + response = json.loads(json_parsable_str) + except Exception as e: + print(f"WARNING: Failed to parse sufficiency check response JSON: {res.output_text}\nFinal string used: {json_parsable_str} error: {e}") + response = {"is_sufficient": False} + + res_episodes = [] + for idx in response.get("indices", []): + if idx < 0 or idx >= len(episodes): + print(f"WARNING: sufficiency check returned invalid episode index: {idx}, len(episodes)={len(episodes)}") + continue + res_episodes.append(episodes[idx]) + + return response, res_episodes, res.usage.input_tokens, res.usage.output_tokens + + async def check_sufficiency_batch( + self, + episodes: list[Node], + possible_relevant_episodes: Collection[Node], + query: str, + ) -> tuple[bool, list[Node], set[Node], int, int]: + input_tokens = 0 + output_tokens = 0 + episode_batch = [] + possible_relevant = set(possible_relevant_episodes) + for e in episodes: + if e in possible_relevant: + continue + episode_batch.append(e) + if len(episode_batch) >= 10 or e == episodes[-1]: + res, suff_episodes, it, ot = await self.check_sufficiency(query, episode_batch) + input_tokens += it + output_tokens += ot + if res['is_sufficient']: + sorted_episodes = sorted( + suff_episodes, + key=lambda e: (e.properties.get('timestamp') is None, + e.properties.get('timestamp')) + ) + reasoning_str = "Inputs:\n" + for idx, e in enumerate(episode_batch): + reasoning_str += f"[{idx}][{e.properties['timestamp']}] {e.properties['content']}\n" + reasoning_str += f"Reasoning: {res['reasoning']}\n" + return True, sorted_episodes, possible_relevant, reasoning_str, input_tokens, output_tokens + episode_batch =[] + possible_relevant.update(suff_episodes) + return False, [], possible_relevant, "", input_tokens, output_tokens + + async def relation_and_summary_search( + self, + query: str, + session_id: str, + limit: int = 10, + perf_matrix: dict[str, Any] = {}, + ): + # 1. Similarity and fulltext search the query on relation edges + t = time.perf_counter() + q_embedding = (await self._embedder.search_embed([query], max_attempts=3))[0] + perf_matrix["embedding_time"] += time.perf_counter() - t + + t = time.perf_counter() + edges, res_ec_nodes = await self._store.search_similar_edges( + query_text=query, + query_embedding=q_embedding, + embedding_property_name="triple_text_embedding", + limit=max(5, limit * 3), + allowed_relations={"HAS_RELATION"}, + required_properties={"session_id": session_id}, + ) + perf_matrix["edge_search_time"] += time.perf_counter() - t + + ec_node_map = { + n.uuid: n for n in res_ec_nodes + } + # 2. Rerank and get top edges + t = time.perf_counter() + cohere_res = await self.cohere_rerank(edges, score_threshold=0.0, query=query, limit=limit) + perf_matrix["rerank_time"] += time.perf_counter() - t + + added_ec_uuids = set() + edge_search_ec_nodes = [] + for e, _ in cohere_res: + # Source node of relation edge is always the episode cluster + if e.source_uuid not in added_ec_uuids: + edge_search_ec_nodes.append(ec_node_map[e.source_uuid]) + added_ec_uuids.add(e.source_uuid) + + # TODO: Extracr entities from query and do entity node search? + + # # 3. Similarity and fulltext search the query on episode clusters + # t = time.perf_counter() + # res_ec_nodes = await self._store.search_similar_nodes( + # query_embedding=q_embedding, + # embedding_property_name="summary_embedding", + # limit=max(5, limit), + # required_labels={"EpisodeCluster"}, + # required_properties={"session_id": session_id}, + # ) + # perf_matrix["episode_cluster_node_search_time"] += time.perf_counter() - t + + # # 4. Rerank the directly searched episode clusters, then RRF rerank with edge searched episode clusters + # t = time.perf_counter() + # cohere_res = await self.cohere_rerank(res_ec_nodes, score_threshold=0.0, query=query, limit=limit) + # perf_matrix["rerank_time"] += time.perf_counter() - t + + # vec_search_ec = [e for e, _ in cohere_res] + + # # vec_search_ec = [e for e, _ in cohere_res] + # fused = rrf([edge_search_ec_nodes, vec_search_ec], k=50) + + # # RRF returns unique items, return the result directly + # result_episode_clusters = [e for e, _ in fused] + return edge_search_ec_nodes + + async def get_relation_edges_and_episode_clusters( + self, + entity_node_uuids: list[UUID], + session_id: str, + ) -> tuple[list[Node], list[Edge]]: + return await self._store.search_related_nodes_edges_batch( + node_uuids=entity_node_uuids, + index_search_label=":Entity", + allowed_relations={"HAS_RELATION"}, + find_sources=True, + find_targets=False, + limit=None, + required_labels={"EpisodeCluster"}, + required_properties={"session_id": session_id}, + ) + + async def get_included_episodes_in_order_from_clusters( + self, + episode_cluster_uuids: list[UUID], + session_id: str, + ) -> list[Node]: + related_episodes = [] + for uuid in episode_cluster_uuids: + related_episodes.extend( + await self._store.search_related_nodes( + node_uuid=uuid, + index_search_label=":EpisodeCluster", + allowed_relations={"INCLUDES"}, + find_sources=False, + find_targets=True, + limit=None, + required_labels={"Episode"}, + required_properties={"session_id": session_id}, + ) + ) + return related_episodes + + def init_perf_matrix(self) -> dict[str, Any]: + return { + "query": "", + "msg": "", + "embedding_time": 0.0, + "entity_extraction_time": 0.0, + "edge_search_time": 0.0, + "entity_node_search_time": 0.0, + "episode_cluster_node_search_time": 0.0, + "related_node_search_time": 0.0, + "sufficiency_check_time": 0.0, + "rerank_time": 0.0, + "num_sufficiency_checks": 0, + "num_sufficiency_check_episodes": 0, + "num_bfs_iteration": 0, + "num_llm_input_tokens": 0, + "num_llm_output_tokens": 0, + "total_time": 0.0, + "total_return_episodes": 0, + } + + def print_search_perf_matrix( + self, + perf_matrix: dict[str, Any], + ): + print(f"Search Performance Matrics:") + for key, value in perf_matrix.items(): + if type(value) == float: + value = f"{value:.2f}" + print(f" {key}: {value}") + print("=================================================================\n") + + async def search( + self, + query: str, + possible_episodes: list[Node], + session_id: int, + limit: int = 10 + ) -> tuple[list[str], dict[str, Any], bool]: + entity_node_local_cache = set() + episode_cluster_local_cache = set() + + possible_relevant_episodes = set(possible_episodes) + + # Initialize performance matrix for current search + perf_matrix = self.init_perf_matrix() + perf_matrix["query"] = query + + search_start = time.perf_counter() + + # 1. Get initial episode clusters by searching on relation triples and summaries + episode_clusters = await self.relation_and_summary_search(query, session_id, limit, perf_matrix) + if len(episode_clusters) == 0: + perf_matrix["msg"] += "No related episodes clusters found from initial search.\n" + perf_matrix["total_return_episodes"] = 0 + perf_matrix["total_time"] = time.perf_counter() - search_start + self.print_search_perf_matrix(perf_matrix) + return [], perf_matrix, False + + episode_cluster_uuids = [e.uuid for e in episode_clusters] + episode_cluster_local_cache.update(episode_cluster_uuids) + + # 2. Get related episodes from the episode clusters, order by episode cluster orders. + # There should be no duplicates since each episode only belongs to one episode cluster. + t = time.perf_counter() + related_episodes = await self.get_included_episodes_in_order_from_clusters( + episode_cluster_uuids, + session_id, + ) + perf_matrix["related_node_search_time"] += time.perf_counter() - t + + perf_matrix["msg"] += f"Found {len(episode_clusters)} initial related episode clusters and expands to {len(related_episodes)} episodes.\n" + + # DEBUG: check duplication + if len(related_episodes) != len(set([e.uuid for e in related_episodes])): + raise ValueError(f"Related episodes from episode clusters contain duplicates, total {len(related_episodes)} vs unique {len(set([e.uuid for e in related_episodes]))}") + + # 3. Rerank the related episodes + t = time.perf_counter() + cohere_res = await self.cohere_rerank(related_episodes, score_threshold=0.0, query=query, limit=None) + perf_matrix["rerank_time"] += time.perf_counter() - t + related_episodes = [e for e, _ in cohere_res] + + # 3.Check sufficiency, on all related episodes, using 'limit' as batch size + t = time.perf_counter() + is_sufficient, sorted_suff_episodes, possible_relevant, _, itoken, otoken = await self.check_sufficiency_batch( + related_episodes, + possible_relevant_episodes, + query, + ) + perf_matrix["num_llm_input_tokens"] += itoken + perf_matrix["num_llm_output_tokens"] += otoken + perf_matrix["sufficiency_check_time"] += time.perf_counter() - t + perf_matrix["num_sufficiency_checks"] += len(related_episodes) // limit + (1 if len(related_episodes) % limit != 0 else 0) + perf_matrix["num_sufficiency_check_episodes"] += len(related_episodes) + if is_sufficient: + if len(sorted_suff_episodes) > limit: + perf_matrix["msg"] += f"Number of sufficient episodes exceed limit: {limit}/{len(suff_episodes)}. Result truncated.\n" + # Rerank sufficient episodes to get top 'limit' episodes + t = time.perf_counter() + cohere_res = await self.cohere_rerank(sorted_suff_episodes, score_threshold=0.0, query=query, limit=limit) + perf_matrix["rerank_time"] += time.perf_counter() - t + sorted_suff_episodes = sorted( + list([e for e, _ in cohere_res]), + key=lambda e: (e.properties.get('timestamp') is None, + e.properties.get('timestamp')) + ) + perf_matrix["msg"] += f"Sufficient from initial retrieved episodes\n" + perf_matrix["total_return_episodes"] = len(sorted_suff_episodes) + perf_matrix["total_time"] = time.perf_counter() - search_start + self.print_search_perf_matrix(perf_matrix) + return sorted_suff_episodes, perf_matrix, is_sufficient + + possible_relevant_episodes.update(possible_relevant) + + # 4. Get total number of episode cluster in current session + ec_all = await self._store.search_matching_nodes( + limit=None, + required_labels={"EpisodeCluster"}, + required_properties={"session_id": session_id}, + ) + + # 5. Start BFS, expanding from initial episode clusters + num_all_ec = len(ec_all) + episode_cluster_queue = deque(episode_clusters) + while len(episode_cluster_queue) > 0: + if len(episode_cluster_local_cache) >= num_all_ec: + if (len(episode_cluster_local_cache) > num_all_ec): + perf_matrix["msg"] += f"WARNING: total related episode cluster {len(episode_cluster_local_cache)} greater than session total {num_all_ec}\n" + break + + perf_matrix["num_bfs_iteration"] += 1 + episode_cluster = episode_cluster_queue.popleft() + + # perf_matrix["msg"] += f" DEBUG: curr iter: {perf_matrix["num_bfs_iteration"]}, queue size: {len(episode_cluster_queue)}, possible relevant episodes size: {len(possible_relevant_episodes)}\n" + + # i). Get related entity nodes + t = time.perf_counter() + res = await self._store.search_related_nodes( + node_uuid=episode_cluster.uuid, + index_search_label=":EpisodeCluster", + allowed_relations={"HAS_RELATION"}, + find_sources=False, + find_targets=True, + limit=None, + required_labels={"Entity"}, + required_properties={"session_id": session_id}, + ) + perf_matrix["related_node_search_time"] += time.perf_counter() - t + entity_nodes_uuids = [] + for n in res: + if n.uuid in entity_node_local_cache: + continue + entity_nodes_uuids.append(n.uuid) + entity_node_local_cache.add(n.uuid) + + # perf_matrix["msg"] += f" DEBUG: Found {len(entity_nodes_uuids)} new source entity nodes\n" + if len(entity_nodes_uuids) == 0: + continue + + # ii). Get all connected episode clusters and relation edges from the entity nodes above. + t = time.perf_counter() + res_ec, res_edges = await self.get_relation_edges_and_episode_clusters( + entity_nodes_uuids, + session_id, + ) + perf_matrix["related_node_search_time"] += time.perf_counter() - t + # Dedup search result(do not move this inside get_relation_edges_and_episode_clusters() to + # be more clear). Notice need to dedup relation edges first. The order matters because + # we append nodes to episode_cluster_local_cache when deduping new episode clusters. + direct_relation_edges = [] + for re in res_edges: + if re.source_uuid in episode_cluster_local_cache: + continue + direct_relation_edges.append(re) + + direct_episode_clusters = [] + for e in res_ec: + if e.uuid in episode_cluster_local_cache: + continue + direct_episode_clusters.append(e) + episode_cluster_local_cache.add(e.uuid) + + # perf_matrix["msg"] += f" DEBUG: Get {len(direct_episode_clusters)} direct episode clusters and {len(direct_relation_edges)} direct relation edges\n" + + # iii). Get connected Entity node via RELATED_TO edges from the entity nodes above + t = time.perf_counter() + res_nodes, _ = await self._store.search_related_nodes_edges_batch( + node_uuids=entity_nodes_uuids, + index_search_label=":Entity", + allowed_relations={"RELATED_TO"}, + find_sources=False, + find_targets=True, + limit=None, + required_labels={"Entity"}, + required_properties={"session_id": session_id}, + ) + perf_matrix["related_node_search_time"] += time.perf_counter() - t + + indirect_entity_nodes_uuids = [] + for n in res_nodes: + if n.uuid in entity_node_local_cache: + continue + indirect_entity_nodes_uuids.append(n.uuid) + entity_node_local_cache.add(n.uuid) + + # iv). Get the connected episode clusters and relation edges from the target entity nodes above. + #. These episode clusters are the '1st level of indirectly related' ecs of the episode + # cluster of current BFS iteration. + t = time.perf_counter() + res_ec, res_edges = await self.get_relation_edges_and_episode_clusters( + indirect_entity_nodes_uuids, + session_id, + ) + perf_matrix["related_node_search_time"] += time.perf_counter() - t + # Dedup search result, see comments above for details + indirect_relation_edges = [] + for re in res_edges: + if re.source_uuid in episode_cluster_local_cache: + continue + indirect_relation_edges.append(re) + + indirect_episode_clusters = [] + for e in res_ec: + if e.uuid in episode_cluster_local_cache: + continue + indirect_episode_clusters.append(e) + episode_cluster_local_cache.add(e.uuid) + + # perf_matrix["msg"] += f" DEBUG: Get {len(indirect_episode_clusters)} indirect episode clusters and {len(indirect_relation_edges)} indirect relation edges\n" + + # DEBUG: check relation edge duplication + if len(direct_relation_edges) + len(indirect_relation_edges) != len(set([e.uuid for e in direct_relation_edges + indirect_relation_edges])): + raise ValueError(f"Relation edges from unique episode cluster to entity nodes contain duplicates, total {len(direct_relation_edges) + len(indirect_relation_edges)} vs unique {len(set([e.uuid for e in direct_relation_edges + indirect_relation_edges]))}") + + new_ec_map = { + e.uuid: e for e in direct_episode_clusters + indirect_episode_clusters + } + + if len(new_ec_map) == 0: + continue + + # # Rrerank relations to get ranked episode clusters + # t = time.perf_counter() + # cohere_res = await self.cohere_rerank(direct_relation_edges + indirect_relation_edges, score_threshold=0.0, query=query, limit=None) + # perf_matrix["rerank_time"] += time.perf_counter() - t + + # relation_ranked_episode_clusters = [] + # added_ec_uuids = set() + # for e, _ in cohere_res: + # # Source node of relation edge is always the episode cluster + # if e.source_uuid in added_ec_uuids: + # continue + # added_ec_uuids.add(e.source_uuid) + # relation_ranked_episode_clusters.append(new_ec_map[e.source_uuid]) + + # # iv). Rerank new episode clusters based on summary + # t = time.perf_counter() + # cohere_res = await self.cohere_rerank(direct_episode_clusters + indirect_episode_clusters, 0, query, limit=None) + # perf_matrix["rerank_time"] += time.perf_counter() - t + # summary_reranked_episode_clusters = [e for e, _ in cohere_res] + + # # v). RRF fuse the two reranked episode cluster lists + # fused = rrf([relation_ranked_episode_clusters, summary_reranked_episode_clusters], k=50) + + # # RRF returns unique items + # episode_cluster_uuids = [e.uuid for e, _ in fused] + + # vi). Get related episodes from the fused episode clusters, order by episode cluster orders. + t = time.perf_counter() + related_episodes = await self.get_included_episodes_in_order_from_clusters( + [uuid for uuid in new_ec_map.keys()], + session_id, + ) + # perf_matrix["msg"] += f" DEBUG: BFS iteration {perf_matrix['num_bfs_iteration']}: Retrieved {len(related_episodes)} related episodes from {len(new_ec_map)} new episode clusters.\n" + perf_matrix["related_node_search_time"] += time.perf_counter() - t + + # 3. Rerank the related episodes + t = time.perf_counter() + cohere_res = await self.cohere_rerank(related_episodes, score_threshold=0.0, query=query, limit=None) + perf_matrix["rerank_time"] += time.perf_counter() - t + related_episodes = [e for e, _ in cohere_res] + + # perf_matrix["msg"] += f" DEBUG: BFS iteration {perf_matrix['num_bfs_iteration']}: Check sufficiency on {len(related_episodes)} new related episodes from fused episode clusters\n" + + # vii). Check sufficiency on all new related episodes + t = time.perf_counter() + is_sufficient, sorted_suff_episodes, possible_relevant, _, itoken, otoken = await self.check_sufficiency_batch( + related_episodes, + possible_relevant_episodes, + query, + ) + perf_matrix["num_llm_input_tokens"] += itoken + perf_matrix["num_llm_output_tokens"] += otoken + perf_matrix["sufficiency_check_time"] += time.perf_counter() - t + perf_matrix["num_sufficiency_checks"] += len(related_episodes) // limit + (1 if len(related_episodes) % limit != 0 else 0) + perf_matrix["num_sufficiency_check_episodes"] += len(related_episodes) + if is_sufficient: + if len(sorted_suff_episodes) > limit: + perf_matrix["msg"] += f"Number of sufficient episodes exceed limit from BFS retrival: {limit}/{len(suff_episodes)}. Result truncated.\n" + # Rerank sufficient episodes to get top 'limit' episodes + t = time.perf_counter() + cohere_res = await self.cohere_rerank(sorted_suff_episodes, score_threshold=0.0, query=query, limit=limit) + perf_matrix["rerank_time"] += time.perf_counter() - t + sorted_suff_episodes = sorted( + list([e for e, _ in cohere_res]), + key=lambda e: (e.properties.get('timestamp') is None, + e.properties.get('timestamp')) + ) + perf_matrix["msg"] += f"Sufficient from BFS retrieved episodes.\n" + perf_matrix["total_return_episodes"] = len(sorted_suff_episodes) + perf_matrix["total_time"] = time.perf_counter() - search_start + self.print_search_perf_matrix(perf_matrix) + return sorted_suff_episodes, perf_matrix, is_sufficient + possible_relevant_episodes.update(possible_relevant) + + # viii). Not sufficient, add to queue + episode_cluster_queue.extend(direct_episode_clusters + indirect_episode_clusters) + + # 5. BFS finished but still not sufficient, rerank possible_relevant_episodes if there are more than limit, otherwise return all + if len(possible_relevant_episodes) > limit: + print(f"BFS finished but not sufficient, rerank possible relevant episodes from {len(possible_relevant_episodes)} to {limit}.") + t = time.perf_counter() + cohere_res = await self.cohere_rerank(list(possible_relevant_episodes), 0, query, limit=limit) + possible_relevant_episodes = set([e for e, _ in cohere_res]) + + sorted_episodes = sorted( + list(possible_relevant_episodes), + key=lambda e: (e.properties.get('timestamp') is None, + e.properties.get('timestamp')) + ) + perf_matrix["msg"] += f"BFS finished but not sufficient, return possible relevant episodes.\n" + perf_matrix["total_return_episodes"] = len(sorted_episodes) + perf_matrix["total_time"] = time.perf_counter() - search_start + self.print_search_perf_matrix(perf_matrix) + return sorted_episodes, perf_matrix, False