from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.concurrency import run_in_threadpool
from fastapi.websockets import WebSocketState
from gradio_client import Client
import gradio as gr
import uvicorn
import httpx
import websockets
import asyncio
import http.cookiejar
from urllib.parse import urljoin, urlparse, unquote
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be relayed by a proxy in either direction. "host" is dropped so that
# httpx sets the target's own Host header.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "host",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)

_HTTP_SCHEMES = ("http://", "https://")
_WS_SCHEMES = ("ws://", "wss://")


def _is_absolute_url(url: str) -> bool:
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parsed = urlparse(url)
    except ValueError:  # e.g. an unbalanced IPv6 bracket such as "http://[::1"
        return False
    return bool(parsed.scheme and parsed.netloc)


def _append_query(url: str, query: str) -> str:
    """Return *url* with the raw *query* string merged into its query component.

    An empty *query* returns *url* unchanged; an existing query on *url* is
    kept and *query* is joined to it with "&".
    """
    if not query:
        return url
    parsed = urlparse(url)
    merged = f"{parsed.query}&{query}" if parsed.query else query
    return parsed._replace(query=merged).geturl()


class HttpClient:
    """Relay HTTP requests through one shared ``httpx.AsyncClient``."""

    def __init__(self):
        # One pooled client for the whole app. Redirects are handed back to the
        # caller instead of being followed here. The cookie jar rejects every
        # domain: this client serves many callers, and a persistent shared jar
        # would replay one caller's upstream cookies on another caller's
        # requests. Cookies still pass end-to-end through the forwarded
        # Cookie / Set-Cookie headers.
        no_cookie_jar = http.cookiejar.CookieJar(
            policy=http.cookiejar.DefaultCookiePolicy(allowed_domains=[])
        )
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            follow_redirects=False,
            cookies=no_cookie_jar,
        )

    async def forward_request(self, request: Request, target_url: str) -> Response:
        """Forward *request* to *target_url* and relay the upstream response.

        The method, non-hop-by-hop headers and body of *request* are sent
        unchanged. The upstream body is relayed as raw (still content-encoded)
        bytes, so the upstream Content-Encoding / Content-Length headers remain
        accurate.

        Returns:
            The upstream response, or a 504 on timeout, a 502 on network
            errors and a 500 on any other failure.
        """
        method = request.method
        try:
            # A list of pairs (rather than a dict) keeps repeated headers intact.
            headers = [
                (key, value)
                for key, value in request.headers.items()
                if key.lower() not in _HOP_BY_HOP_HEADERS
            ]
            body = await request.body()

            logger.info("Forwarding %s request to %s", method, target_url)

            upstream_request = self.client.build_request(
                method, target_url, headers=headers, content=body
            )
            # Stream so the body can be read raw: httpx's `.content` would be
            # decompressed and no longer match the relayed Content-Encoding.
            upstream = await self.client.send(upstream_request, stream=True)
            try:
                raw_body = b"".join([chunk async for chunk in upstream.aiter_raw()])
            finally:
                await upstream.aclose()

            # Pass the first value of each header to the constructor and append
            # repeats afterwards, so multi-valued headers such as Set-Cookie are
            # not folded into a single comma-joined value.
            first_values = {}
            repeated = []
            for key, value in upstream.headers.multi_items():
                lowered = key.lower()
                if lowered in _HOP_BY_HOP_HEADERS:
                    continue
                if lowered in first_values:
                    repeated.append((lowered, value))
                else:
                    first_values[lowered] = value

            response = Response(
                content=raw_body,
                status_code=upstream.status_code,
                headers=first_values,
            )
            for key, value in repeated:
                response.headers.append(key, value)
            return response

        except httpx.TimeoutException:
            logger.error("Timeout error while forwarding request to %s", target_url)
            return Response(content="Request timeout error", status_code=504)
        except httpx.NetworkError as e:
            logger.error("Network error while forwarding request: %s", e)
            return Response(content=f"Network error: {str(e)}", status_code=502)
        except Exception as e:
            logger.exception("Error forwarding request: %s", e)
            return Response(content=f"Request error: {str(e)}", status_code=500)

    async def close(self):
        """Close the underlying httpx client and its connection pool."""
        await self.client.aclose()


# Initialize the HTTP client
http_client = HttpClient()


def _read_index() -> str:
    """Return the contents of ./index.html decoded as UTF-8."""
    with open("index.html", encoding="utf-8") as f:
        return f.read()


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve index.html; the file is read in a worker thread, off the event loop."""
    return await run_in_threadpool(_read_index)


@app.post("/gp/{repo_id:path}/{api_name:path}")
async def gradio_client(repo_id: str, api_name: str, request: Request):
    """Call ``/{api_name}`` on the Gradio app *repo_id* with the JSON body's "args".

    The request body must be a JSON object whose "args" list is passed
    positionally to ``Client.predict``. Both the ``Client`` construction and
    the prediction are blocking calls, so both run in a worker thread.
    """
    data = await request.json()
    client = await run_in_threadpool(Client, repo_id)
    result = await run_in_threadpool(
        client.predict, *data["args"], api_name=f"/{api_name}"
    )
    return result


@app.api_route(
    "/wp/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
)
async def web_client(request: Request, path: str):
    """Forward a request to the URL in *path* or in the ``X-Target-Url`` header.

    A *path* starting with http:// or https:// is used as the target directly.
    Otherwise the ``X-Target-Url`` header is the base URL and *path* is joined
    onto it. The incoming query string is appended to the target. Responds
    with 400 when no target is given or the target is not an absolute URL.
    """
    if path.lower().startswith(_HTTP_SCHEMES):
        target_url = path
    else:
        target_url = request.headers.get("X-Target-Url")
        if target_url and path:
            # Validate the header on its own before joining the path onto it.
            if not _is_absolute_url(target_url):
                return Response(content="Invalid X-Target-Url header", status_code=400)
            target_url = urljoin(target_url.rstrip("/") + "/", path.lstrip("/"))

    if not target_url:
        return Response(
            content="Missing X-Target-Url header or URL in path", status_code=400
        )
    if not _is_absolute_url(target_url):
        return Response(
            content="Invalid X-Target-Url header or URL in path", status_code=400
        )

    # The route's path parameter excludes the query string; forward it too.
    target_url = _append_query(target_url, request.url.query)
    return await http_client.forward_request(request, target_url)


async def _close_websocket(websocket: WebSocket, code: int = 1000, reason=None) -> None:
    """Close *websocket* unless either side has already disconnected (best effort)."""
    if WebSocketState.DISCONNECTED in (websocket.application_state, websocket.client_state):
        return
    try:
        await websocket.close(code=code, reason=reason)
    except Exception as exc:  # the peer may vanish mid-close; nothing left to do
        logger.debug("Ignoring error while closing WebSocket: %s", exc)


async def _pump_client_to_target(websocket: WebSocket, target_ws) -> None:
    """Relay text and binary frames from the client until either side closes."""
    try:
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                return
            if message.get("text") is not None:
                await target_ws.send(message["text"])
            elif message.get("bytes") is not None:
                await target_ws.send(message["bytes"])
    except websockets.ConnectionClosed:
        pass


async def _pump_target_to_client(websocket: WebSocket, target_ws) -> None:
    """Relay text and binary frames from the target until either side closes."""
    try:
        async for message in target_ws:
            if isinstance(message, bytes):
                await websocket.send_bytes(message)
            else:
                await websocket.send_text(message)
    except (websockets.ConnectionClosed, WebSocketDisconnect):
        pass


async def _relay_websocket(websocket: WebSocket, target_ws) -> None:
    """Run both pumps and stop the other one as soon as either finishes.

    Waiting for both pumps (as gather would) would leave the relay hanging on
    the surviving side's receive after its peer disconnected.
    """
    tasks = [
        asyncio.create_task(_pump_client_to_target(websocket, target_ws)),
        asyncio.create_task(_pump_target_to_client(websocket, target_ws)),
    ]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()  # no-op for a task that already finished
        results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, Exception):
            logger.warning("WebSocket relay stopped with error: %s", result)


@app.websocket("/wp/{path:path}")
async def websocket_client(websocket: WebSocket, path: str):
    """Relay a WebSocket connection to a target server.

    The target is the ``X-Target-Url`` header (http/https URLs are rewritten
    to ws/wss) with *path* joined onto it. When the header is absent, *path*
    itself is the target if it is a ws:// or wss:// URL; percent-encoded paths
    are decoded first. The client's query string is forwarded. Text and binary
    frames are relayed both ways until either side disconnects.
    """
    header_target = websocket.headers.get("X-Target-Url")
    if header_target:
        target_url = header_target
        # A path that is itself a ws(s) URL is not appended to the header URL.
        path_suffix = "" if path.lower().startswith(_WS_SCHEMES) else path
    else:
        decoded_path = unquote(path) if "%" in path else path
        if not decoded_path.lower().startswith(_WS_SCHEMES):
            await websocket.close(
                code=1008, reason="Missing X-Target-Url header or invalid URL in path"
            )
            return
        target_url = decoded_path
        # The (decoded) path is already the complete URL: never append it again.
        path_suffix = ""

    if not _is_absolute_url(target_url):
        await websocket.close(code=1008, reason="Invalid target URL")
        return

    # Accept the WebSocket connection
    await websocket.accept()

    # Convert HTTP/HTTPS URL to WebSocket URL
    lowered = target_url.lower()
    if lowered.startswith("https://"):
        ws_target_url = "wss://" + target_url[len("https://"):]
    elif lowered.startswith("http://"):
        ws_target_url = "ws://" + target_url[len("http://"):]
    else:
        ws_target_url = target_url

    if path_suffix:
        ws_target_url = urljoin(ws_target_url.rstrip("/") + "/", path_suffix.lstrip("/"))
    ws_target_url = _append_query(ws_target_url, websocket.url.query)

    try:
        async with websockets.connect(ws_target_url) as target_ws:
            await _relay_websocket(websocket, target_ws)
    except websockets.InvalidURI:
        await _close_websocket(websocket, code=1008, reason="Invalid WebSocket URL")
    except websockets.InvalidHandshake:
        await _close_websocket(websocket, code=1008, reason="WebSocket handshake failed")
    except Exception as e:
        logger.exception("Error in WebSocket connection: %s", e)
        await _close_websocket(websocket, code=1011, reason="Internal server error")
    finally:
        # Normal-closure for the relay-finished path; a no-op if already closed.
        await _close_websocket(websocket)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "ok"}


@app.on_event("shutdown")
async def shutdown_event():
    """Release the shared HTTP client's connections on application shutdown."""
    await http_client.close()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)