diff --git a/.env.example b/.env.example
index 6fb8013..cbc7fc7 100644
--- a/.env.example
+++ b/.env.example
@@ -1,6 +1,7 @@
 POSTGRES_URL="postgresql://postgres:password@localhost:5432/gitdiagram"
 NEXT_PUBLIC_API_DEV_URL=http://localhost:8000
-
+ENABLE_LOCAL_SERVER=False
+LOCAL_CODEBASE="/app/code"
 OPENAI_API_KEY=
 
 # OPTIONAL: providing your own GitHub PAT increases rate limits from 60/hr to 5000/hr to the GitHub API
diff --git a/README.md b/README.md
index c5ef8eb..3d96bd8 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,55 @@ pnpm dev
 
 You can now access the website at `localhost:3000` and edit the rate limits defined in `backend/app/routers/generate.py` in the generate function decorator.
 
+## How to run it on a local directory
+
+Make sure you can already complete the Self-hosting setup correctly.
+
+1. Create the env file:
+```bash
+cp .env.example .env
+```
+Then set `ENABLE_LOCAL_SERVER` to `True`.
+
+2. Mount the directory
+In `docker-compose.yml`, mount your local directory under `volumes`. For example, to generate a diagram of the directory `/Users/xxx/my-project`, add the line `- /Users/xxx/my-project:/app/code` as shown below:
+```
+services:
+  api:
+    build:
+      context: ./backend
+      dockerfile: Dockerfile
+    ports:
+      - "8000:8000"
+    volumes:
+      - ./backend:/app
+      - /Users/xxx/my-project:/app/code
+    env_file:
+      - .env
+    environment:
+      - ENVIRONMENT=${ENVIRONMENT:-development} # Default to development if not set
+    restart: unless-stopped
+```
+
+3. Run the following commands to launch the services:
+```bash
+docker-compose up --build -d
+chmod +x start-database.sh
+./start-database.sh
+pnpm db:push
+pnpm dev
+```
+
+4. Go to `localhost:3000` and enter any valid GitHub URL, e.g. `https://github.com/yufansong/gitdiagram`. It will not actually generate a diagram for that repository; instead, it triggers generation for the local directory you mounted earlier.
+
+5. If you hit a "syntax error" like the one in `https://github.com/ahmedkhaleel2004/gitdiagram/issues/64`, it results from the limits of the model's ability: the LLM-generated Mermaid code is invalid. A temporary workaround:
+- Go to the `backend` directory, where you will find `mermaid.txt`.
+- Paste its contents into an online Mermaid editor such as `https://www.mermaidchart.com/play`; it should report the same error.
+- Give `mermaid.txt` and the error log to GPT and ask it for corrected Mermaid code.
+- Go back to `https://www.mermaidchart.com/play` and retry; you should now get a rendered diagram.
+
+At least for me, this workaround has produced correct results several times.
+
 ## Contributing
 
 Contributions are welcome! Please feel free to submit a Pull Request.
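Reviewer note on the README section above: local mode depends entirely on the bind mount making your project visible at `/app/code` inside the `api` container. A minimal, hedged sanity check under that assumption (run inside the container, e.g. via `docker compose exec api python`):

```python
# Sketch: verify the LOCAL_CODEBASE mount before generating a diagram.
# Assumes the /app/code mount and env values from the examples above.
import os
from pathlib import Path

local_path = os.getenv("LOCAL_CODEBASE", "/app/code")
root = Path(local_path)
assert root.is_dir(), f"{local_path} is not mounted into the container"

# Print a handful of files so you can confirm the right project is mounted
for i, p in enumerate(sorted(root.rglob("*"))):
    if p.is_file():
        print(p.relative_to(root))
    if i >= 20:
        break
```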
diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py
index 7ed649e..fd93454 100644
--- a/backend/app/routers/generate.py
+++ b/backend/app/routers/generate.py
@@ -2,7 +2,9 @@
 from fastapi.responses import StreamingResponse
 from dotenv import load_dotenv
 from app.services.github_service import GitHubService
+from app.services.local_service import LocalService
 from app.services.o3_mini_openai_service import OpenAIo3Service
+from app.services.seed_dpsk_service import SeedDpSkService
 from app.prompts import (
     SYSTEM_FIRST_PROMPT,
     SYSTEM_SECOND_PROMPT,
@@ -15,12 +17,14 @@
 import re
 import json
 import asyncio
-
 # from app.services.claude_service import ClaudeService
 # from app.core.limiter import limiter
-
+import os
 load_dotenv()
-
+# Accept the common truthy spellings ("true"/"True"/"TRUE"/"1");
+# os.getenv always returns a string or None, never the int 1.
+ENABLE_LOCAL_SERVER = os.getenv("ENABLE_LOCAL_SERVER", "").strip().lower() in ("true", "1")
+LOCAL_CODEBASE = os.getenv("LOCAL_CODEBASE")
 router = APIRouter(prefix="/generate", tags=["Claude"])
 
 # Initialize services
@@ -28,21 +32,40 @@
 o3_service = OpenAIo3Service()
 
 
+def dump_mermaid_code(mermaid_code):
+    output_dir = "/app"
+    os.makedirs(output_dir, exist_ok=True)
+    with open(os.path.join(output_dir, "mermaid.txt"), "w") as f:
+        f.write(mermaid_code)  # write straight to mermaid.txt for debugging
+    return
+
+
 # cache github data to avoid double API calls from cost and generate
 @lru_cache(maxsize=100)
 def get_cached_github_data(username: str, repo: str, github_pat: str | None = None):
+    if ENABLE_LOCAL_SERVER:
+        return get_cached_local_data(local_path=LOCAL_CODEBASE)
     # Create a new service instance for each call with the appropriate PAT
-    current_github_service = GitHubService(pat=github_pat)
+    service = GitHubService(pat=github_pat)
 
-    default_branch = current_github_service.get_default_branch(username, repo)
+    default_branch = service.get_default_branch(username, repo)
     if not default_branch:
         default_branch = "main"  # fallback value
 
-    file_tree = current_github_service.get_github_file_paths_as_list(username, repo)
-    readme = current_github_service.get_github_readme(username, repo)
+    file_tree = service.get_file_paths_as_list(username, repo)
+    readme = service.get_readme(username, repo)
 
     return {"default_branch": default_branch, "file_tree": file_tree, "readme": readme}
 
 
+@lru_cache(maxsize=100)
+def get_cached_local_data(local_path: str):
+    # Scan the mounted local codebase instead of calling the GitHub API
+    service = LocalService(path=local_path)
+    file_tree = service.get_file_paths_as_list()
+    readme = service.get_readme()
+    return {"default_branch": "Your Current Branch", "file_tree": file_tree, "readme": readme}
+
+
 class ApiRequest(BaseModel):
     username: str
@@ -246,6 +269,10 @@ async def event_generator():
                 mermaid_code, body.username, body.repo, default_branch
             )
 
+            # Dump the generated Mermaid code to mermaid.txt;
+            # useful when debugging invalid diagrams (see README step 5)
+            dump_mermaid_code(mermaid_code)
+
             # Send final result
             yield f"data: {json.dumps({
                 'status': 'complete',
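Reviewer note on `get_cached_local_data` above: because it is wrapped in `lru_cache`, the mounted directory is scanned once per process, and later requests with the same path reuse that result, so edits to the local codebase will not show up until the backend restarts. A standalone sketch of that behavior (hypothetical stub, not the real service):

```python
from functools import lru_cache

@lru_cache(maxsize=100)
def get_cached_local_data(local_path: str):
    # In the real router this scans the tree via LocalService; stubbed here.
    print(f"scanning {local_path} ...")  # printed only on the first call
    return {"default_branch": "Your Current Branch", "file_tree": "...", "readme": "..."}

get_cached_local_data("/app/code")  # scans
get_cached_local_data("/app/code")  # cache hit: no rescan, no print
```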
diff --git a/backend/app/services/base_service.py b/backend/app/services/base_service.py
new file mode 100644
index 0000000..8989ba6
--- /dev/null
+++ b/backend/app/services/base_service.py
@@ -0,0 +1,52 @@
+from abc import ABC, abstractmethod
+
+
+class BaseService(ABC):
+    @abstractmethod
+    def get_file_paths_as_list(self, username, repo):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_readme(self, username, repo):
+        raise NotImplementedError
+
+    # Shared utility methods can be implemented here
+    def _should_include_file(self, path):
+        # Patterns to exclude, matched as plain substrings of the lowercased path
+        excluded_patterns = [
+            # Dependencies
+            "node_modules/",
+            "vendor/",
+            "venv/",
+            # Compiled files
+            ".min.",
+            ".pyc",
+            ".pyo",
+            ".pyd",
+            ".so",
+            ".dll",
+            ".class",
+            # Asset files
+            ".jpg",
+            ".jpeg",
+            ".png",
+            ".gif",
+            ".ico",
+            ".svg",
+            ".ttf",
+            ".woff",
+            ".webp",
+            # Cache and temporary files
+            "__pycache__/",
+            ".cache/",
+            ".tmp/",
+            # Lock files and logs
+            "yarn.lock",
+            "poetry.lock",
+            "*.log",  # note: a literal substring, so this entry never actually matches
+            # Configuration files
+            ".vscode/",
+            ".idea/",
+        ]
+
+        return not any(pattern in path.lower() for pattern in excluded_patterns)
\ No newline at end of file
diff --git a/backend/app/services/github_service.py b/backend/app/services/github_service.py
index 33a4d42..ca13f24 100644
--- a/backend/app/services/github_service.py
+++ b/backend/app/services/github_service.py
@@ -4,11 +4,11 @@
 from datetime import datetime, timedelta
 from dotenv import load_dotenv
 import os
+from .base_service import BaseService
 
 load_dotenv()
 
-
-class GitHubService:
+class GitHubService(BaseService):
     def __init__(self, pat: str | None = None):
         # Try app authentication first
         self.client_id = os.getenv("GITHUB_CLIENT_ID")
@@ -107,7 +107,7 @@ def get_default_branch(self, username, repo):
             return response.json().get("default_branch")
         return None
 
-    def get_github_file_paths_as_list(self, username, repo):
+    def get_file_paths_as_list(self, username, repo):
         """
         Fetches the file tree of an open-source GitHub repository,
         excluding static files and generated code.
@@ -119,47 +119,6 @@ def get_github_file_paths_as_list(self, username, repo):
         Returns:
             str: A filtered and formatted string of file paths in the repository, one per line.
         """
-
-        def should_include_file(path):
-            # Patterns to exclude
-            excluded_patterns = [
-                # Dependencies
-                "node_modules/",
-                "vendor/",
-                "venv/",
-                # Compiled files
-                ".min.",
-                ".pyc",
-                ".pyo",
-                ".pyd",
-                ".so",
-                ".dll",
-                ".class",
-                # Asset files
-                ".jpg",
-                ".jpeg",
-                ".png",
-                ".gif",
-                ".ico",
-                ".svg",
-                ".ttf",
-                ".woff",
-                ".webp",
-                # Cache and temporary files
-                "__pycache__/",
-                ".cache/",
-                ".tmp/",
-                # Lock files and logs
-                "yarn.lock",
-                "poetry.lock",
-                "*.log",
-                # Configuration files
-                ".vscode/",
-                ".idea/",
-            ]
-
-            return not any(pattern in path.lower() for pattern in excluded_patterns)
-
         # Try to get the default branch first
         branch = self.get_default_branch(username, repo)
         if branch:
@@ -174,7 +133,7 @@ def should_include_file(path):
             paths = [
                 item["path"]
                 for item in data["tree"]
-                if should_include_file(item["path"])
+                if self._should_include_file(item["path"])
             ]
             return "\n".join(paths)
 
@@ -191,7 +150,7 @@ def should_include_file(path):
             paths = [
                 item["path"]
                 for item in data["tree"]
-                if should_include_file(item["path"])
+                if self._should_include_file(item["path"])
             ]
             return "\n".join(paths)
 
@@ -199,7 +158,7 @@ def should_include_file(path):
             "Could not fetch repository file tree. Repository might not exist, be empty or private."
         )
 
-    def get_github_readme(self, username, repo):
+    def get_readme(self, username, repo):
         """
         Fetches the README contents of an open-source GitHub repository.
 
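Reviewer note on `_should_include_file`: it is a plain substring match against the lowercased path, which is worth remembering when extending the list (for instance, the `"*.log"` entry is matched literally and never fires). A small hedged probe using a throwaway subclass, since `BaseService` is abstract:

```python
from app.services.base_service import BaseService

class _Probe(BaseService):
    # Concrete no-op overrides so the abstract class can be instantiated
    def get_file_paths_as_list(self, username, repo):
        raise NotImplementedError

    def get_readme(self, username, repo):
        raise NotImplementedError

probe = _Probe()
for path in ["src/main.py", "node_modules/react/index.js", "assets/Logo.PNG", "server.log"]:
    print(path, probe._should_include_file(path))
# -> True, False, False, True  (server.log passes because "*.log" never matches)
```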
diff --git a/backend/app/services/local_service.py b/backend/app/services/local_service.py
new file mode 100644
index 0000000..163ba3d
--- /dev/null
+++ b/backend/app/services/local_service.py
@@ -0,0 +1,75 @@
+import os
+from pathlib import Path
+
+from dotenv import load_dotenv
+
+from .base_service import BaseService
+
+load_dotenv()
+
+
+class LocalService(BaseService):
+    def __init__(self, path: str | None = None):
+        self.path = path
+        base_path = Path(self.path)
+        if not base_path.exists():
+            raise ValueError(f"Local path does not exist: {self.path}")
+        if not base_path.is_dir():
+            raise ValueError("A valid directory path must be provided")
+
+    def get_file_paths_as_list(self):
+        """
+        Get the file paths of the local codebase, excluding static files and generated code.
+
+        Returns:
+            str: newline-separated list of file paths after filtering
+        """
+        def scan_directory(current: Path):
+            """Recursively scan directories and filter files."""
+            paths = []
+            for entry in current.iterdir():
+                if entry.name.startswith('.'):
+                    continue
+                if entry.is_dir():
+                    paths.extend(scan_directory(entry))
+                else:
+                    # Keep paths relative to the codebase root (base_path from
+                    # the enclosing scope), not the subdirectory being scanned
+                    file_path = str(entry.relative_to(base_path))
+                    if self._should_include_file(file_path):
+                        paths.append(file_path)
+            return paths
+
+        try:
+            base_path = Path(self.path)
+            if not base_path.exists():
+                raise ValueError(f"Local path does not exist: {self.path}")
+            if not base_path.is_dir():
+                raise ValueError("A valid directory path must be provided")
+
+            all_files = scan_directory(base_path)
+            return "\n".join(all_files)
+
+        except Exception as e:
+            raise ValueError(f"Cannot read the local codebase: {str(e)}")
+
+    def get_readme(self):
+        """
+        Get the README file content of the local codebase.
+
+        Returns:
+            str: README file content
+
+        Raises:
+            FileNotFoundError: raised when no readable README file is found
+        """
+        readme_files = ['README.md', 'README', 'readme.md']
+        for filename in readme_files:
+            readme_path = Path(self.path) / filename
+            if readme_path.is_file():
+                try:
+                    return readme_path.read_text(encoding='utf-8')
+                except UnicodeDecodeError:
+                    continue
+
+        raise FileNotFoundError("Cannot find a readable README file (supported: README.md/README/readme.md)")
\ No newline at end of file
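Reviewer note: a hedged usage sketch of `LocalService`; the `/app/code` path is the mount assumed in the README/compose examples earlier in this patch:

```python
from app.services.local_service import LocalService

svc = LocalService(path="/app/code")       # raises ValueError if the path is missing
print(svc.get_file_paths_as_list()[:500])  # newline-separated paths, relative to the root
print(svc.get_readme()[:200])              # raises FileNotFoundError without a README
```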
diff --git a/backend/app/services/seed_dpsk_service.py b/backend/app/services/seed_dpsk_service.py
new file mode 100644
index 0000000..8d8961d
--- /dev/null
+++ b/backend/app/services/seed_dpsk_service.py
@@ -0,0 +1,117 @@
+import json
+from typing import AsyncGenerator, Literal
+
+import aiohttp
+import tiktoken
+from dotenv import load_dotenv
+
+from app.utils.format_message import format_user_message
+
+load_dotenv()
+
+
+class SeedDpSkService:
+    def __init__(self):
+        self.api_key = ""  # set a valid Volcano Ark API key before using this service
+        self.encoding = tiktoken.get_encoding("o200k_base")
+        # Full request path, including the /chat/completions suffix
+        self.base_url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
+
+    # Keeps the OpenAIo3Service method name so the router can swap services
+    async def call_o3_api_stream(
+        self,
+        system_prompt: str,
+        data: dict,
+        api_key: str | None = None,
+        reasoning_effort: Literal["low", "medium", "high"] = "low",
+    ) -> AsyncGenerator[str, None]:
+        """
+        Makes a streaming API call to DeepSeek R1 on Volcano Ark and yields the responses.
+
+        Args:
+            system_prompt (str): The instruction/system prompt
+            data (dict): Dictionary of variables to format into the user message
+            api_key (str | None): Optional custom API key
+
+        Yields:
+            str: Chunks of the model's response text
+        """
+        # Create the user message with the data
+        user_message = format_user_message(data)
+
+        headers = {
+            "Content-Type": "application/json",
+            # Prefer a per-call key when one is supplied
+            "Authorization": f"Bearer {api_key or self.api_key}",
+        }
+
+        payload = {
+            "model": "deepseek-r1-250120",
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_message},
+            ],
+            "max_completion_tokens": 12000,
+            "stream": True,
+            # Forward the caller's reasoning_effort setting
+            "reasoning_effort": reasoning_effort,
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    self.base_url, headers=headers, json=payload
+                ) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        print(f"Error response: {error_text}")
+                        raise ValueError(
+                            f"Volcano Ark API returned status code {response.status}: {error_text}"
+                        )
+
+                    line_count = 0
+                    async for line in response.content:
+                        line = line.decode("utf-8").strip()
+                        if not line:
+                            continue
+
+                        line_count += 1
+
+                        if line.startswith("data: "):
+                            if line == "data: [DONE]":
+                                break
+                            try:
+                                data = json.loads(line[6:])
+                                content = (
+                                    data.get("choices", [{}])[0]
+                                    .get("delta", {})
+                                    .get("content")
+                                )
+                                if content:
+                                    yield content
+                            except json.JSONDecodeError as e:
+                                print(f"JSON decode error: {e} for line: {line}")
+                                continue
+
+                    if line_count == 0:
+                        print("Warning: No lines received in stream response")
+
+        except aiohttp.ClientError as e:
+            print(f"Connection error: {str(e)}")
+            raise ValueError(f"Failed to connect to the Volcano Ark API: {str(e)}")
+        except Exception as e:
+            print(f"Unexpected error in streaming API call: {str(e)}")
+            raise
+
+    def count_tokens(self, prompt: str) -> int:
+        """
+        Counts the number of tokens in a prompt.
+
+        Args:
+            prompt (str): The prompt to count tokens for
+
+        Returns:
+            int: Estimated number of input tokens
+        """
+        num_tokens = len(self.encoding.encode(prompt))
+        return num_tokens
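Reviewer note: a hedged driver for `SeedDpSkService.call_o3_api_stream`, assuming `self.api_key` has been filled in with a valid Volcano Ark key and that `format_user_message` accepts the same `file_tree`/`readme` fields the router passes:

```python
import asyncio

from app.services.seed_dpsk_service import SeedDpSkService

async def main():
    svc = SeedDpSkService()
    async for chunk in svc.call_o3_api_stream(
        system_prompt="Explain this codebase.",
        data={"file_tree": "src/main.py\nREADME.md", "readme": "# demo"},
    ):
        print(chunk, end="", flush=True)

asyncio.run(main())
```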