add poc import

This commit is contained in:
Autumn.home 2024-07-18 23:50:50 +08:00
parent af1290df4e
commit cb9cb2b047
3 changed files with 101 additions and 31 deletions

View File

@ -3,11 +3,15 @@
# @auth: rainy-autumn@outlook.com # @auth: rainy-autumn@outlook.com
# @version: # @version:
import os import os
import shutil
import traceback
import yaml import yaml
from bson import ObjectId from bson import ObjectId
from fastapi import APIRouter, Depends, File, UploadFile from fastapi import APIRouter, Depends, File, UploadFile
from motor.motor_asyncio import AsyncIOMotorCursor from motor.motor_asyncio import AsyncIOMotorCursor
from starlette.background import BackgroundTasks
from api.users import verify_token from api.users import verify_token
from core.db import get_mongo_db from core.db import get_mongo_db
from pymongo import ASCENDING, DESCENDING from pymongo import ASCENDING, DESCENDING
@ -53,39 +57,102 @@ def is_safe_path(base_path, target_path):
@router.post("/poc/data/import") @router.post("/poc/data/import")
async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token)): async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token), background_tasks: BackgroundTasks = BackgroundTasks()):
if not file.filename.endswith('.zip'): if not file.filename.endswith('.zip'):
return {"message": "not zip", "code": 500} return {"message": "not zip", "code": 500}
file_name = generate_random_string(5)
relative_path = f'file\\{file_name}.zip'
zip_file_path = os.path.join(os.getcwd(), relative_path)
with open(zip_file_path, "wb") as f:
f.write(await file.read())
yaml_files = [] background_tasks.add_task(import_poc_handle, file)
unzip_path = f'file\\{file_name}' return {"message": "正在导入中", "code": 200}
file_path = os.path.join(os.getcwd(), unzip_path)
extract_path = file_path
os.makedirs(extract_path, exist_ok=True)
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
for member in zip_ref.namelist():
member_path = os.path.join(extract_path, member)
if not is_safe_path(extract_path, member_path):
return {"message": "Unsafe file path detected in ZIP file", "code": 500}
zip_ref.extractall(extract_path)
for root, dirs, files in os.walk(extract_path):
for filename in files: async def import_poc_handle(file):
if filename.endswith('.yaml') or filename.endswith('.yml'): logger.info("POC导入开始")
file_path = os.path.join(root, filename) async for db in get_mongo_db():
yaml_files.append(file_path) file_name = generate_random_string(5)
for yaml_file in yaml_files: relative_path = f'file\\{file_name}.zip'
with open(yaml_file, 'r') as stream: zip_file_path = os.path.join(os.getcwd(), relative_path)
try: with open(zip_file_path, "wb") as f:
data = yaml.safe_load(stream) f.write(await file.read())
print(data["id"])
except: yaml_files = []
pass unzip_path = f'file\\{file_name}'
file_path = os.path.join(os.getcwd(), unzip_path)
extract_path = file_path
os.makedirs(extract_path, exist_ok=True)
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
for member in zip_ref.namelist():
member_path = os.path.join(extract_path, member)
if not is_safe_path(extract_path, member_path):
logger.error("Unsafe file path detected in ZIP file")
return
zip_ref.extractall(extract_path)
for root, dirs, files in os.walk(extract_path):
for filename in files:
if filename.endswith('.yaml'):
file_path = os.path.join(root, filename)
yaml_files.append(file_path)
hash_doc = await db.PocList.find({}, {"hash": 1, "_id": 0}).to_list(length=None)
hash_list = [item["hash"] for item in hash_doc]
success_num = 0
error_num = 0
repeat_num = 0
insert_error_num = 0
severity_dic = {
"critical": 6,
"high": 5,
"medium": 4,
"low": 3,
"info": 2,
"unkown": 1
}
logger.info(f"{len(yaml_files)}个POC")
poc_data_list = []
for yaml_file in yaml_files:
with open(yaml_file, 'r', encoding='utf-8') as stream:
try:
file_content = stream.read()
data = yaml.safe_load(file_content)
name = data["id"]
if "severity" in data["info"]:
severity = data["info"]["severity"]
else:
severity = "unkown"
hash = calculate_md5_from_content(file_content)
if hash in hash_list:
repeat_num += 1
continue
if severity in severity_dic:
severity = severity_dic[severity]
else:
severity = 1
formatted_time = get_now_time()
data = {
"name": name,
"content": file_content,
"hash": hash,
"level": severity,
"time": formatted_time
}
poc_data_list.append(data)
except:
logger.info(f"POC导入 读取文件失败: {yaml_file}")
logger.error(traceback.format_exc())
error_num += 1
continue
if len(poc_data_list) != 0:
result = await db.PocList.insert_many(poc_data_list)
if result.inserted_ids:
success_num += len(result.inserted_ids)
await refresh_config('all', 'poc')
logger.info(f"POC更新成功: {success_num} 重复:{repeat_num} 失败: {error_num}")
try:
os.remove(zip_file_path)
shutil.rmtree(extract_path)
except:
logger.error(traceback.format_exc())
logger.error("删除POC文件出错")
logger.info("POC导入结束")
@router.get("/poc/data/all") @router.get("/poc/data/all")
@ -194,6 +261,9 @@ async def add_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = D
hash_value = calculate_md5_from_content(content) hash_value = calculate_md5_from_content(content)
level = request_data.get("level") level = request_data.get("level")
formatted_time = get_now_time() formatted_time = get_now_time()
doc = await db.PocList.find_one({"hash": hash_value}, {"_id": 1})
if doc:
return {"message": "POC已存在", "code": 500}
# Insert data into the database # Insert data into the database
result = await db.PocList.insert_one({ result = await db.PocList.insert_one({
"name": name, "name": name,

View File

@ -18,7 +18,7 @@ from core.db import get_mongo_db
def calculate_md5_from_content(content): def calculate_md5_from_content(content):
md5 = hashlib.md5() md5 = hashlib.md5()
md5.update(content.encode("utf-8")) # 将内容编码为 utf-8 后更新 MD5 md5.update(content.encode("utf-8"))
return md5.hexdigest() return md5.hexdigest()

View File

@ -261,4 +261,4 @@ def banner():
if __name__ == "__main__": if __name__ == "__main__":
banner() banner()
file_path = os.path.join(os.getcwd(), "file") file_path = os.path.join(os.getcwd(), "file")
uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=True, reload_excludes=[file_path]) uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=False, reload_excludes=[file_path])