diff --git a/main.py b/main.py index ca26b8a..47f9d05 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ from pathlib import Path +from typing import Optional -from fastapi import FastAPI, Response, status +from fastapi import FastAPI, status from fastapi.responses import JSONResponse from pydantic import BaseModel @@ -20,10 +21,31 @@ class Log(BaseModel): end: bool +class SearchResult(BaseModel): + matches: list[int] + + class Error(BaseModel): error: str +def handle_path(name: str) -> tuple[Path, Optional[JSONResponse]]: + path = data_path.joinpath(name) + + # Prevent path traversal + if not path.is_relative_to(data_path): + return path, JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, content={"error": "Bad Request"} + ) + + if not path.is_file(): + return path, JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, content={"error": "Log not found"} + ) + + return path, None + + @app.get("/", response_model=Files) def list_logs(): return {"files": [path.name for path in data_path.glob("*.txt") if path.is_file()]} @@ -38,18 +60,9 @@ def list_logs(): }, ) def get_log(name: str, start: int = 0, size: int = 100): - path = data_path.joinpath(name) - - # Prevent path traversal - if not path.is_relative_to(data_path): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, content={"error": "Bad Request"} - ) - - if not path.is_file(): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, content={"error": "Log not found"} - ) + path, resp = handle_path(name) + if resp: + return resp with open(path) as f: f.seek(start) @@ -64,3 +77,48 @@ def get_log(name: str, start: int = 0, size: int = 100): "content": content, "end": size + start >= total, } + + +@app.get( + "/log/{name}/search/", + response_model=SearchResult, + responses={ + 400: {"model": Error}, + 404: {"model": Error}, + }, +) +def search_log(name: str, query: str, start: int = 0): + path, resp = handle_path(name) + if resp: + return resp + + matches = [] + buffer_size = max(len(query), 1024) + break_outer = False + with open(path) as f: + f.seek(start) + head = start + tail = start + 2 * buffer_size + buffer = f.read(tail) + while True: + offset = 0 + while True: + index = buffer.find(query, offset) + if index != -1: + matches.append(head + index) + offset = index + 1 + if len(matches) >= 100: + break_outer = True + break + else: + break + if break_outer: + break + + data = f.read(buffer_size) + if not data: + break + buffer = buffer[-len(query):] + data + head = tail - len(query) + tail += len(data) + return {"matches": matches}