8000 strictly type python backend · ag-python/screenshot-to-code@6a28ee2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6a28ee2

Browse files
committed
strictly type python backend
1 parent 68a8d27 commit 6a28ee2

File tree

8 files changed

+66
-37
lines changed

8 files changed

+66
-37
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"python.analysis.typeCheckingMode": "strict"
3+
}

backend/image_generation.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import asyncio
2-
import os
32
import re
3+
from typing import Dict, List, Union
44
from openai import AsyncOpenAI
55
from bs4 import BeautifulSoup
66

77

8-
async def process_tasks(prompts, api_key, base_url):
8+
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
99
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
1010
results = await asyncio.gather(*tasks, return_exceptions=True)
1111

12-
processed_results = []
12+
processed_results: List[Union[str, None]] = []
1313
for result in results:
1414
if isinstance(result, Exception):
1515
print(f"An exception occurred: {result}")
@@ -20,9 +20,9 @@ async def process_tasks(prompts, api_key, base_url):
2020
return processed_results
2121

2222

23-
async def generate_image(prompt, api_key, base_url):
23+
async def generate_image(prompt: str, api_key: str, base_url: str):
2424
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
25-
image_params = {
25+
image_params: Dict[str, Union[str, int]] = {
2626
"model": "dall-e-3",
2727
"quality": "standard",
2828
"style": "natural",
@@ -35,7 +35,7 @@ async def generate_image(prompt, api_key, base_url):
3535
return res.data[0].url
3636

3737

38-
def extract_dimensions(url):
38+
def extract_dimensions(url: str):
3939
# Regular expression to match numbers in the format '300x200'
4040
matches = re.findall(r"(\d+)x(\d+)", url)
4141

@@ -48,11 +48,11 @@ def extract_dimensions(url):
4848
return (100, 100)
4949

5050

51-
def create_alt_url_mapping(code):
51+
def create_alt_url_mapping(code: str) -> Dict[str, str]:
5252
soup = BeautifulSoup(code, "html.parser")
5353
images = soup.find_all("img")
5454

55-
mapping = {}
55+
mapping: Dict[str, str] = {}
5656

5757
for image in images:
5858
if not image["src"].startswith("https://placehold.co"):
@@ -61,7 +61,9 @@ def create_alt_url_mapping(code):
6161
return mapping
6262

6363

64-
async def generate_images(code, api_key, base_url, image_cache):
64+
async def generate_images(
65+
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
66+
):
6567
# Find all images
6668
soup = BeautifulSoup(code, "html.parser")
6769
images = soup.find_all("img")

backend/llm.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import os
2-
from typing import Awaitable, Callable
1+
from typing import Awaitable, Callable, List
32
from openai import AsyncOpenAI
3+
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
44

55
MODEL_GPT_4_VISION = "gpt-4-vision-preview"
66

77

88
async def stream_openai_response(
9-
messages,
9+
messages: List[ChatCompletionMessageParam],
1010
api_key: str,
1111
base_url: str | None,
1212
callback: Callable[[str], Awaitable[None]],
13-
):
13+
) -> str:
1414
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
1515

1616
model = MODEL_GPT_4_VISION
@@ -23,9 +23,10 @@ async def stream_openai_response(
2323
params["max_tokens"] = 4096
2424
params["temperature"] = 0
2525

26-
completion = await client.chat.completions.create(**params)
26+
stream = await client.chat.completions.create(**params) # type: ignore
2727
full_response = ""
28-
async for chunk in completion:
28+
async for chunk in stream: # type: ignore
29+
assert isinstance(chunk, ChatCompletionChunk)
2930
content = chunk.choices[0].delta.content or ""
3031
full_response += content
3132
await callback(content)

backend/main.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
from fastapi.responses import HTMLResponse
1515
import openai
1616
from llm import stream_openai_response
17-
from mock import mock_completion
17+
from openai.types.chat import ChatCompletionMessageParam
18+
from mock_llm import mock_completion
1819
from utils import pprint_prompt
20+
from typing import Dict, List
1921
from image_generation import create_alt_url_mapping, generate_images
2022
from prompts import assemble_prompt
2123
from routes import screenshot
@@ -53,7 +55,7 @@ async def get_status():
5355
)
5456

5557

56-
def write_logs(prompt_messages, completion):
58+
def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
5759
# Get the logs path from environment, default to the current working directory
5860
logs_path = os.environ.get("LOGS_PATH", os.getcwd())
5961

@@ -84,7 +86,8 @@ async def throw_error(
8486
await websocket.send_json({"type": "error", "value": message})
8587
await websocket.close()
8688

87-
params = await websocket.receive_json()
89+
# TODO: Are the values always strings?
90+
params: Dict[str, str] = await websocket.receive_json()
8891

8992
print("Received params")
9093

@@ -154,7 +157,7 @@ async def throw_error(
154157
print("generating code...")
155158
await websocket.send_json({"type": "status", "value": "Generating code..."})
156159

157-
async def process_chunk(content):
160+
async def process_chunk(content: str):
158161
await websocket.send_json({"type": "chunk", "value": content})
159162

160163
# Assemble the prompt
@@ -176,15 +179,23 @@ async def process_chunk(content):
176179
return
177180

178181
# Image cache for updates so that we don't have to regenerate images
179-
image_cache = {}
182+
image_cache: Dict[str, str] = {}
180183

181184
if params["generationType"] == "update":
182185
# Transform into message format
183186
# TODO: Move this to frontend
184187
for index, text in enumerate(params["history"]):
185-
prompt_messages += [
186-
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
187-
]
188+
if index % 2 == 0:
189+
message: ChatCompletionMessageParam = {
190+
"role": "assistant",
191+
"content": text,
192+
}
193+
else:
194+
message: ChatCompletionMessageParam = {
195+
"role": "user",
196+
"content": text,
197+
}
198+
prompt_messages.append(message)
188199
image_cache = create_alt_url_mapping(params["history"][-2])
189200

190201
if SHOULD_MOCK_AI_RESPONSE:

backend/mock.py renamed to backend/mock_llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
2+
from typing import Awaitable, Callable
23

34

4-
async def mock_completion(process_chunk):
5+
async def mock_completion(process_chunk: Callable[[str], Awaitable[None]]) -> str:
56
code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE
67

78
for i in range(0, len(code_to_return), 10):

backend/prompts.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from typing import List, Union
2+
3+
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam
4+
5+
16
TAILWIND_SYSTEM_PROMPT = """
27
You are an expert Tailwind developer
38
You take screenshots of a reference web page from the user, and then build single page apps
@@ -117,8 +122,10 @@
117122

118123

119124
def assemble_prompt(
120-
image_data_url, generated_code_config: str, result_image_data_url=None
121-
):
125+
image_data_url: str,
126+
generated_code_config: str,
127+
result_image_data_url: Union[str, None] = None,
128+
) -> List[ChatCompletionMessageParam]:
122129
# Set the system prompt based on the output settings
123130
system_content = TAILWIND_SYSTEM_PROMPT
124131
if generated_code_config == "html_tailwind":
@@ -132,7 +139,7 @@ def assemble_prompt(
132139
else:
133140
raise Exception("Code config is not one of available options")
134141

135-
user_content = [
142+
user_content: List[ChatCompletionContentPartParam] = [
136143
{
137144
"type": "image_url",
138145
"image_url": {"url": image_data_url, "detail": "high"},

backend/routes/screenshot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str:
1111
return f"data:{mime_type};base64,{base64_image}"
1212

1313

14-
async def capture_screenshot(target_url, api_key, device="desktop") -> bytes:
14+
async def capture_screenshot(
15+
target_url: str, api_key: str, device: str = "desktop"
16+
) -> bytes:
1517
api_base_url = "https://api.screenshotone.com/take"
1618

1719
params = {

backend/utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
import copy
22
import json
3+
from typing import List
4+
from openai.types.chat import ChatCompletionMessageParam
35

46

5-
def pprint_prompt(prompt_messages):
7+
def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]):
68
print(json.dumps(truncate_data_strings(prompt_messages), indent=4))
79

810

9 D2D6 -
def truncate_data_strings(data):
11+
def truncate_data_strings(data: List[ChatCompletionMessageParam]): # type: ignore
1012
# Deep clone the data to avoid modifying the original object
1113
cloned_data = copy.deepcopy(data)
1214

1315
if isinstance(cloned_data, dict):
14-
for key, value in cloned_data.items():
16+
for key, value in cloned_data.items(): # type: ignore
1517
# Recursively call the function if the value is a dictionary or a list
1618
if isinstance(value, (dict, list)):
17-
cloned_data[key] = truncate_data_strings(value)
19+
cloned_data[key] = truncate_data_strings(value) # type: ignore
1820
# Truncate the string if it it's long and add ellipsis and length
1921
elif isinstance(value, str):
20-
cloned_data[key] = value[:40]
22+
cloned_data[key] = value[:40] # type: ignore
2123
if len(value) > 40:
22-
cloned_data[key] += "..." + f" ({len(value)} chars)"
24+
cloned_data[key] += "..." + f" ({len(value)} chars)" # type: ignore
2325

24-
elif isinstance(cloned_data, list):
26+
elif isinstance(cloned_data, list): # type: ignore
2527
# Process each item in the list
26-
cloned_data = [truncate_data_strings(item) for item in cloned_data]
28+
cloned_data = [truncate_data_strings(item) for item in cloned_data] # type: ignore
2729

28-
return cloned_data
30+
return cloned_data # type: ignore

0 commit comments

Comments
 (0)
0