Nathan's branch by CodCodingCode · Pull Request #24 · CodCodingCode/AMIE-app · GitHub

Nathan's branch #24


Open

CodCodingCode wants to merge 28 commits into main

Changes from 1 commit

Commits
39b65e7
Fixed diagnostic logic
CodCodingCode Jun 3, 2025
e87067f
added thinking and answer into json outputs
CodCodingCode Jun 7, 2025
3263a55
Update gen_convo.py
CodCodingCode Jun 7, 2025
24c54e5
testing prompt using specific vignette (there is a vignette issue wit…
CodCodingCode Jun 7, 2025
7acd9d2
fixed globalization issue
CodCodingCode Jun 7, 2025
7d12874
fixed defining function because of Extraction error
CodCodingCode Jun 7, 2025
da45f59
started new dataset generation
CodCodingCode Jun 7, 2025
b7d0088
generated new dataset
CodCodingCode Jun 8, 2025
7480df6
added deployment code to finally deploy model to website
CodCodingCode Jun 8, 2025
75e4726
Added print and break statements
CodCodingCode Jun 8, 2025
618a240
set up endpoint for the model - ALMOST ABLE TO CONNECT TO FRONTEND
CodCodingCode Jun 8, 2025
71237bc
fixed major model thinking issue - tested model once again
CodCodingCode Jun 9, 2025
d1410fd
fixed prompting of the model. added openai caching
CodCodingCode Jun 9, 2025
1b3ada4
fixed prompting so that model would not hallucinate + changed summari…
CodCodingCode Jun 10, 2025
015032a
added new files
CodCodingCode Jun 10, 2025
9f58f7d
generated new interesting dataset
CodCodingCode Jun 11, 2025
9b07483
added reward for hallucination
CodCodingCode Jun 11, 2025
d8414af
Edited hallucination file - added patient training
CodCodingCode Jun 12, 2025
8ac87f5
fixed hallucination code fully.
CodCodingCode Jun 12, 2025
faebffa
created new endpoint to test full loop
CodCodingCode Jun 12, 2025
c2564aa
created code for training model to not hallucinate. Added multiple hal…
CodCodingCode Jun 12, 2025
222fa15
got rid of api keys lol
CodCodingCode Jun 12, 2025
2fca862
Changed prompting
CodCodingCode Jun 12, 2025
617ecfc
switched Diagnoser - changed patient and summarizer logic
CodCodingCode Jun 12, 2025
dea890f
added further prompt engineering to make the model better
CodCodingCode Jun 12, 2025
3d89e46
added an evaluation meter that can easily evaluate the quality of the…
CodCodingCode Jun 12, 2025
a427988
debugging codebase for errors. added strictness to evaluator. Made ev…
CodCodingCode Jun 13, 2025
0dea88b
Fixed file pathing for eval_data
CodCodingCode Jun 13, 2025
Edited hallucination file - added patient training
CodCodingCode committed Jun 12, 2025
commit d8414af5d8acb91c278af29294d5787570de75a6
18 changes: 2 additions & 16 deletions new_data_gen/SFT /endpoint_test.py
@@ -6,8 +6,8 @@
 # ============================================================================
 # REPLACE THESE WITH YOUR ACTUAL VALUES
 # ============================================================================
-ENDPOINT_URL = "https://glg6vtpv72vt2jad.us-east-1.aws.endpoints.huggingface.cloud" # Your endpoint URL from the screenshot
-HF_TOKEN = "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" # Your HuggingFace token
+ENDPOINT_URL = "url" # Your endpoint URL from the screenshot
+HF_TOKEN = "token" # Your HuggingFace token


 class HuggingFaceInference:
@@ -259,20 +259,6 @@ def run_conversation():
     print("\n✅ Conversation finished!")


-# Optional: Test just one generation
-def test_single_generation():
-    test_prompt = """
-Instruction: You are a clinical summarizer. Given a transcript of a doctor–patient dialogue, extract a structured clinical vignette summarizing the key symptoms, relevant history, and any diagnostic clues.
-Input: I am 14. I am a male. I have pain in my stomach. Previous Vignette:
-Output: THINKING:
-"""
-
-    print("🧪 Testing single generation...")
-    result = model_client.generate(test_prompt, max_new_tokens=200)
-    print("📝 Result:")
-    print(result)
-
-
 if __name__ == "__main__":
     print("🚀 Clinical AI with HuggingFace Inference Endpoint")
     print(f"🔗 Endpoint: {ENDPOINT_URL}")
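The body of the HuggingFaceInference class is collapsed in this diff. For orientation, here is a minimal sketch of what such a client could look like — an assumption, not code from this PR: the class name and the generate(prompt, max_new_tokens) signature are taken from the calls shown above, while the request/response format follows the standard Hugging Face Inference Endpoints text-generation API.

# Hypothetical sketch of the collapsed HuggingFaceInference client (not from this diff).
# generate() mirrors the model_client.generate(test_prompt, max_new_tokens=200) call above.
import requests

class HuggingFaceInference:
    def __init__(self, endpoint_url: str, hf_token: str):
        self.endpoint_url = endpoint_url
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, max_new_tokens: int = 200) -> str:
        payload = {
            "inputs": prompt,
            # return_full_text=False keeps only the newly generated tokens
            "parameters": {"max_new_tokens": max_new_tokens, "return_full_text": False},
        }
        resp = requests.post(self.endpoint_url, headers=self.headers, json=payload, timeout=120)
        resp.raise_for_status()
        # Text-generation endpoints return a list like [{"generated_text": "..."}]
        return resp.json()[0]["generated_text"]

model_client = HuggingFaceInference(ENDPOINT_URL, HF_TOKEN)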
81 changes: 81 additions & 0 deletions new_data_gen/SFT /patient.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)

# 1) Load and inspect data
raw_ds = load_dataset("CodCodingCode/patient-agent-V1.3", split="train")
print(raw_ds.features, raw_ds[0])


# 2) Alpaca-style prompt formatter
def format_fn(example):
    return {
        "text": (
            f"### Instruction:\nYou are a patient agent. Please act as if you are a real patient with the following vignette and conversation.\n\n"
            f"### Input:\n{example['input']}\n\n"
            f"### Output:\n{example['output']}"
        )
    }


ds = raw_ds.map(format_fn, remove_columns=raw_ds.column_names)

# 3) Tokenizer + base model in BF16
model_name = "meta-llama/Llama-3.1-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# Set pad token for batching
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
# Align pad_token_id
model.config.pad_token_id = tokenizer.eos_token_id


# 4) Tokenize and prepare labels for causal LM
def tokenize_fn(examples):
    tokens = tokenizer(examples["text"], padding=True, truncation=True, max_length=2048)
    # Use input_ids also as labels for full-model fine-tuning
    tokens["labels"] = tokens.input_ids.copy()
    return tokens


train_ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])

# 5) Training configuration for full fine-tuning
training_args = TrainingArguments(
    output_dir="outputs/full_finetune",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_steps=2250,  # adjust as needed for your epochs
    learning_rate=2e-5,  # lower LR for full-model fine-tuning
    optim="adamw_torch",
    bf16=True,  # GH200 supports BF16
    fp16=False,
    logging_steps=10,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    tokenizer=tokenizer,
)

# 6) Fine-tune all model weights
trainer.train()

# 7) Save final model
trainer.save_model("outputs/full_finetune")
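A quick smoke test of the saved checkpoint could reuse the same Alpaca-style template that format_fn builds above. This is a minimal illustrative sketch, not part of the commit: the vignette/dialogue text and sampling parameters are made up, and it assumes the tokenizer was saved alongside the model in outputs/full_finetune.

# Illustrative smoke test for the fine-tuned patient agent (not in this PR).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = "outputs/full_finetune"
tok = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
mdl = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map="auto")

# Same template as format_fn; the vignette below is a made-up example.
prompt = (
    "### Instruction:\nYou are a patient agent. Please act as if you are a real patient "
    "with the following vignette and conversation.\n\n"
    "### Input:\nVignette: 14-year-old male with stomach pain.\nDoctor: What brings you in today?\n\n"
    "### Output:\n"
)
inputs = tok(prompt, return_tensors="pt").to(mdl.device)
out = mdl.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
# Decode only the newly generated tokens
print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))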
2 changes: 1 addition & 1 deletion new_data_gen/SFT /setup.txt
@@ -3,7 +3,7 @@
 # • CUDA 12.8 + cuDNN (confirmed)

 # 2) Create & activate a Python venv
-cd ~/project
+cd ~/train
 python3 -m venv llama3-ft
 source llama3-ft/bin/activate
 pip install --upgrade pip setuptools wheel
82 changes: 47 additions & 35 deletions new_data_gen/grpo_infra/S/hallucination.py
@@ -7,14 +7,14 @@
 import torch

 # ─── 0. HF TOKEN ─────────────────────────────────────────────────
-HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
+HF_TOKEN = "token"
 print("[debug] HF_TOKEN:", HF_TOKEN[:8] + "…" if HF_TOKEN else None)
 if not HF_TOKEN:
     raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN")

 # ─── 1. Download model + checkpoint via snapshot_download ────────
-REPO_ID = "CodCodingCode/llama-3.1-8b-clinical-v1.2"
-SUBFOLDER = "checkpoint-4500"
+REPO_ID = "CodCodingCode/llama-3.1-8b-clinical-v1.3"
+SUBFOLDER = "checkpoint-6508"
 print(f"[debug] Downloading {REPO_ID}…")
 cache_dir = snapshot_download(repo_id=REPO_ID, token=HF_TOKEN)
 print("[debug] snapshot_download complete, cache_dir:", cache_dir)
@@ -52,7 +52,11 @@
 def make_prompt(ex):
     instr = ex["instruction"].strip()
     inp = ex.get("input", "").strip()
-    full = instr + ("\n" + inp if inp else "")
+
+    # Hoisted out of the f-string: a backslash inside an f-string expression is a syntax error before Python 3.12
+    body = ("\n" + inp) if inp else ""
+    full = f"""Instruction: {instr}
+Input: {body}
+Output: THINKING:
+"""
     return {"prompt": full}
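As a sanity check, the reworked make_prompt emits output like this for a made-up record (field values are illustrative, not from the dataset):

# Illustrative check of the new prompt format.
ex = {"instruction": "You are a clinical summarizer.", "input": "I am 14. I have stomach pain."}
print(make_prompt(ex)["prompt"])
# Instruction: You are a clinical summarizer.
# Input: 
# I am 14. I have stomach pain.
# Output: THINKING: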


@@ -93,11 +97,10 @@ def tokenize_batch(batch):

 # ─── 6. Define your ChatGPT-based anti-hallucination reward fn ────────────────────
 from openai import OpenAI
-import time
-import multiprocessing
-import shutil
-from itertools import islice
-import random
+import json


 # Initialize OpenAI client
 client = OpenAI(api_key="api")
@@ -108,51 +111,57 @@ def chatgpt_hallucination_reward(prompts, completions, **kwargs):
     rewards = []
     for idx, (prompt, completion) in enumerate(zip(prompts, completions)):

-        # Extract patient input from prompt if it's a clinical summarizer task
+        # Extract conversation data if it's a clinical summarizer task
         if "clinical summarizer" in prompt.lower():
-            input_match = re.search(
-                r"Input:\s*(.*?)\s*(?:Previous Vignette|Output:|$)", prompt, re.DOTALL
-            )
-            if input_match:
-                patient_input = input_match.group(1).strip()
+            # Extract conversation history
+            conv_match = re.search(r"CONVERSATION HISTORY:\s*(\[.*?\])", prompt, re.DOTALL)
+            # Extract previous vignette
+            vignette_match = re.search(r"PREVIOUS VIGNETTE:\s*(.*?)(?:Output:|$)", prompt, re.DOTALL)

+            if conv_match:
+                conversation_history = conv_match.group(1).strip()
+                previous_vignette = vignette_match.group(1).strip() if vignette_match else ""

                 # Create ChatGPT prompt to evaluate hallucination
-                evaluation_prompt = f"""
-You are an expert clinical fact-checker. Your job is to compare a patient's original statement with a clinical summary and identify any hallucinations or inaccuracies.
+                evaluation_prompt = f"""You are an expert clinical fact-checker. Compare the conversation + previous vignette with the new clinical summary to find hallucinations.

+CONVERSATION HISTORY:
+{conversation_history}

-PATIENT'S ORIGINAL STATEMENT:
-{patient_input}
+PREVIOUS VIGNETTE:
+{previous_vignette}

-CLINICAL SUMMARY TO EVALUATE:
+NEW CLINICAL SUMMARY TO CHECK:
 {completion}

-Please analyze if the clinical summary adds any information that was NOT mentioned by the patient. Look for:
-1. Symptoms the patient never mentioned
-2. Demographic information that contradicts what the patient said
-3. Severity descriptions not provided by the patient
-4. Any other fabricated details
+Find if the summary adds information NOT in the conversation or previous vignette.

-Respond with a JSON object containing:
-- "hallucinated_items": [list of specific things that were added/fabricated]
-- "accurate_items": [list of things correctly extracted from patient statement]
-- "score": a number from -10 to +10 (-10 = severe hallucination, +10 = perfect accuracy)
+JSON response:
+- "hallucinated_items": [things added that weren't in conversation/vignette]
+- "accurate_items": [things correctly from conversation/vignette]
+- "score": -10 to +10 (-10=bad hallucinations, +10=perfect)

 Example response:
-{{"hallucinated_items": ["dizziness", "repeated vomiting"], "accurate_items": ["17-year-old female", "left-sided headache", "photophobia"], "score": -2}}
-"""
+{{"hallucinated_items": [], "accurate_items": [], "score": 0}}"""

                 try:
                     response = client.chat.completions.create(
                         model=model,
                         messages=[{"role": "user", "content": evaluation_prompt}],
                         temperature=0.1,
-                        max_tokens=500,
+                        max_tokens=200,
                     )

-                    # Parse the response
-                    import json
-
-                    result = json.loads(response.choices[0].message.content)
+                    # Get raw response and try to extract JSON
+                    raw_content = response.choices[0].message.content
+
+                    # Try to extract JSON if wrapped in markdown
+                    json_match = re.search(r'\{.*\}', raw_content, re.DOTALL)
+                    if json_match:
+                        json_str = json_match.group(0)
+                    else:
+                        json_str = raw_content
+
+                    result = json.loads(json_str)
                     score = result.get("score", 0.0)
                     hallucinated = result.get("hallucinated_items", [])
                     accurate = result.get("accurate_items", [])
@@ -161,6 +170,9 @@ def chatgpt_hallucination_reward(prompts, completions, **kwargs):
                     print(
                         f"[debug reward] #{idx} → hallucinated: {hallucinated}, accurate: {accurate}, score: {score}"
                     )

+                except json.JSONDecodeError as e:
+                    print(f"[debug reward] #{idx} → JSON error: {e}, raw: {response.choices[0].message.content}")
+                    score = 0.0
                 except Exception as e:
                     print(
                         f"[debug reward] #{idx} → ChatGPT error: {e}, defaulting to 0.0"
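chatgpt_hallucination_reward follows the (prompts, completions, **kwargs) signature that TRL's GRPO trainer expects for reward functions, so the collapsed training setup presumably wires it in along these lines. A minimal sketch, assuming the snapshot-downloaded model and the prompt dataset built earlier in this file; every hyperparameter here is an illustrative guess, not taken from the PR.

# Hypothetical GRPO wiring (the actual trainer setup is collapsed in this diff).
from trl import GRPOConfig, GRPOTrainer

grpo_args = GRPOConfig(
    output_dir="outputs/grpo_hallucination",
    per_device_train_batch_size=2,
    num_generations=4,          # completions sampled per prompt, each scored by the reward fn
    max_completion_length=256,
    logging_steps=10,
)

grpo_trainer = GRPOTrainer(
    model=model,                                   # the snapshot-downloaded clinical model
    reward_funcs=[chatgpt_hallucination_reward],   # called with (prompts, completions, **kwargs)
    args=grpo_args,
    train_dataset=ds,                              # dataset with the "prompt" column from make_prompt
)
grpo_trainer.train()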