From cb9cb2b047f56e20b5ccc43772488c2e8bdbc858 Mon Sep 17 00:00:00 2001 From: "Autumn.home" Date: Thu, 18 Jul 2024 23:50:50 +0800 Subject: [PATCH] add poc import --- api/poc.py | 128 +++++++++++++++++++++++++++++++++++++++------------ core/util.py | 2 +- main.py | 2 +- 3 files changed, 101 insertions(+), 31 deletions(-) diff --git a/api/poc.py b/api/poc.py index 4a6d22e..f48b141 100644 --- a/api/poc.py +++ b/api/poc.py @@ -3,11 +3,15 @@ # @auth: rainy-autumn@outlook.com # @version: import os +import shutil +import traceback import yaml from bson import ObjectId from fastapi import APIRouter, Depends, File, UploadFile from motor.motor_asyncio import AsyncIOMotorCursor +from starlette.background import BackgroundTasks + from api.users import verify_token from core.db import get_mongo_db from pymongo import ASCENDING, DESCENDING @@ -53,39 +57,102 @@ def is_safe_path(base_path, target_path): @router.post("/poc/data/import") -async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token)): +async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token), background_tasks: BackgroundTasks = BackgroundTasks()): if not file.filename.endswith('.zip'): return {"message": "not zip", "code": 500} - file_name = generate_random_string(5) - relative_path = f'file\\{file_name}.zip' - zip_file_path = os.path.join(os.getcwd(), relative_path) - with open(zip_file_path, "wb") as f: - f.write(await file.read()) - yaml_files = [] - unzip_path = f'file\\{file_name}' - file_path = os.path.join(os.getcwd(), unzip_path) - extract_path = file_path - os.makedirs(extract_path, exist_ok=True) - with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: - for member in zip_ref.namelist(): - member_path = os.path.join(extract_path, member) - if not is_safe_path(extract_path, member_path): - return {"message": "Unsafe file path detected in ZIP file", "code": 500} - zip_ref.extractall(extract_path) + background_tasks.add_task(import_poc_handle, file) + return {"message": "正在导入中", "code": 200} - for root, dirs, files in os.walk(extract_path): - for filename in files: - if filename.endswith('.yaml') or filename.endswith('.yml'): - file_path = os.path.join(root, filename) - yaml_files.append(file_path) - for yaml_file in yaml_files: - with open(yaml_file, 'r') as stream: - try: - data = yaml.safe_load(stream) - print(data["id"]) - except: - pass + +async def import_poc_handle(file): + logger.info("POC导入开始") + async for db in get_mongo_db(): + file_name = generate_random_string(5) + relative_path = f'file\\{file_name}.zip' + zip_file_path = os.path.join(os.getcwd(), relative_path) + with open(zip_file_path, "wb") as f: + f.write(await file.read()) + + yaml_files = [] + unzip_path = f'file\\{file_name}' + file_path = os.path.join(os.getcwd(), unzip_path) + extract_path = file_path + os.makedirs(extract_path, exist_ok=True) + with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: + for member in zip_ref.namelist(): + member_path = os.path.join(extract_path, member) + if not is_safe_path(extract_path, member_path): + logger.error("Unsafe file path detected in ZIP file") + return + zip_ref.extractall(extract_path) + + for root, dirs, files in os.walk(extract_path): + for filename in files: + if filename.endswith('.yaml'): + file_path = os.path.join(root, filename) + yaml_files.append(file_path) + hash_doc = await db.PocList.find({}, {"hash": 1, "_id": 0}).to_list(length=None) + hash_list = [item["hash"] for item in hash_doc] + success_num = 0 + error_num = 0 + repeat_num = 0 + insert_error_num = 0 + severity_dic = { + "critical": 6, + "high": 5, + "medium": 4, + "low": 3, + "info": 2, + "unkown": 1 + } + logger.info(f"共{len(yaml_files)}个POC") + poc_data_list = [] + for yaml_file in yaml_files: + with open(yaml_file, 'r', encoding='utf-8') as stream: + try: + file_content = stream.read() + data = yaml.safe_load(file_content) + name = data["id"] + if "severity" in data["info"]: + severity = data["info"]["severity"] + else: + severity = "unkown" + hash = calculate_md5_from_content(file_content) + if hash in hash_list: + repeat_num += 1 + continue + if severity in severity_dic: + severity = severity_dic[severity] + else: + severity = 1 + formatted_time = get_now_time() + data = { + "name": name, + "content": file_content, + "hash": hash, + "level": severity, + "time": formatted_time + } + poc_data_list.append(data) + except: + logger.info(f"POC导入 读取文件失败: {yaml_file}") + logger.error(traceback.format_exc()) + error_num += 1 + continue + if len(poc_data_list) != 0: + result = await db.PocList.insert_many(poc_data_list) + if result.inserted_ids: + success_num += len(result.inserted_ids) + await refresh_config('all', 'poc') + logger.info(f"POC更新成功: {success_num} 重复:{repeat_num} 失败: {error_num}") + try: + os.remove(zip_file_path) + shutil.rmtree(extract_path) + except: + logger.error(traceback.format_exc()) + logger.error("删除POC文件出错") + logger.info("POC导入结束") @router.get("/poc/data/all") @@ -194,6 +261,9 @@ async def add_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = D hash_value = calculate_md5_from_content(content) level = request_data.get("level") formatted_time = get_now_time() + doc = await db.PocList.find_one({"hash": hash_value}, {"_id": 1}) + if doc: + return {"message": "POC已存在", "code": 500} # Insert data into the database result = await db.PocList.insert_one({ "name": name, diff --git a/core/util.py b/core/util.py index a478f08..5ec96ca 100644 --- a/core/util.py +++ b/core/util.py @@ -18,7 +18,7 @@ from core.db import get_mongo_db def calculate_md5_from_content(content): md5 = hashlib.md5() - md5.update(content.encode("utf-8")) # 将内容编码为 utf-8 后更新 MD5 + md5.update(content.encode("utf-8")) return md5.hexdigest() diff --git a/main.py b/main.py index effac83..3f3e088 100644 --- a/main.py +++ b/main.py @@ -261,4 +261,4 @@ def banner(): if __name__ == "__main__": banner() file_path = os.path.join(os.getcwd(), "file") - uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=True, reload_excludes=[file_path]) + uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=False, reload_excludes=[file_path])