add poc import
This commit is contained in:
parent
af1290df4e
commit
cb9cb2b047
84
api/poc.py
84
api/poc.py
|
@ -3,11 +3,15 @@
|
||||||
# @auth: rainy-autumn@outlook.com
|
# @auth: rainy-autumn@outlook.com
|
||||||
# @version:
|
# @version:
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
import traceback
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from fastapi import APIRouter, Depends, File, UploadFile
|
from fastapi import APIRouter, Depends, File, UploadFile
|
||||||
from motor.motor_asyncio import AsyncIOMotorCursor
|
from motor.motor_asyncio import AsyncIOMotorCursor
|
||||||
|
from starlette.background import BackgroundTasks
|
||||||
|
|
||||||
from api.users import verify_token
|
from api.users import verify_token
|
||||||
from core.db import get_mongo_db
|
from core.db import get_mongo_db
|
||||||
from pymongo import ASCENDING, DESCENDING
|
from pymongo import ASCENDING, DESCENDING
|
||||||
|
@ -53,9 +57,17 @@ def is_safe_path(base_path, target_path):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/poc/data/import")
|
@router.post("/poc/data/import")
|
||||||
async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
|
async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||||
if not file.filename.endswith('.zip'):
|
if not file.filename.endswith('.zip'):
|
||||||
return {"message": "not zip", "code": 500}
|
return {"message": "not zip", "code": 500}
|
||||||
|
|
||||||
|
background_tasks.add_task(import_poc_handle, file)
|
||||||
|
return {"message": "正在导入中", "code": 200}
|
||||||
|
|
||||||
|
|
||||||
|
async def import_poc_handle(file):
|
||||||
|
logger.info("POC导入开始")
|
||||||
|
async for db in get_mongo_db():
|
||||||
file_name = generate_random_string(5)
|
file_name = generate_random_string(5)
|
||||||
relative_path = f'file\\{file_name}.zip'
|
relative_path = f'file\\{file_name}.zip'
|
||||||
zip_file_path = os.path.join(os.getcwd(), relative_path)
|
zip_file_path = os.path.join(os.getcwd(), relative_path)
|
||||||
|
@ -71,21 +83,76 @@ async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _:
|
||||||
for member in zip_ref.namelist():
|
for member in zip_ref.namelist():
|
||||||
member_path = os.path.join(extract_path, member)
|
member_path = os.path.join(extract_path, member)
|
||||||
if not is_safe_path(extract_path, member_path):
|
if not is_safe_path(extract_path, member_path):
|
||||||
return {"message": "Unsafe file path detected in ZIP file", "code": 500}
|
logger.error("Unsafe file path detected in ZIP file")
|
||||||
|
return
|
||||||
zip_ref.extractall(extract_path)
|
zip_ref.extractall(extract_path)
|
||||||
|
|
||||||
for root, dirs, files in os.walk(extract_path):
|
for root, dirs, files in os.walk(extract_path):
|
||||||
for filename in files:
|
for filename in files:
|
||||||
if filename.endswith('.yaml') or filename.endswith('.yml'):
|
if filename.endswith('.yaml'):
|
||||||
file_path = os.path.join(root, filename)
|
file_path = os.path.join(root, filename)
|
||||||
yaml_files.append(file_path)
|
yaml_files.append(file_path)
|
||||||
|
hash_doc = await db.PocList.find({}, {"hash": 1, "_id": 0}).to_list(length=None)
|
||||||
|
hash_list = [item["hash"] for item in hash_doc]
|
||||||
|
success_num = 0
|
||||||
|
error_num = 0
|
||||||
|
repeat_num = 0
|
||||||
|
insert_error_num = 0
|
||||||
|
severity_dic = {
|
||||||
|
"critical": 6,
|
||||||
|
"high": 5,
|
||||||
|
"medium": 4,
|
||||||
|
"low": 3,
|
||||||
|
"info": 2,
|
||||||
|
"unkown": 1
|
||||||
|
}
|
||||||
|
logger.info(f"共{len(yaml_files)}个POC")
|
||||||
|
poc_data_list = []
|
||||||
for yaml_file in yaml_files:
|
for yaml_file in yaml_files:
|
||||||
with open(yaml_file, 'r') as stream:
|
with open(yaml_file, 'r', encoding='utf-8') as stream:
|
||||||
try:
|
try:
|
||||||
data = yaml.safe_load(stream)
|
file_content = stream.read()
|
||||||
print(data["id"])
|
data = yaml.safe_load(file_content)
|
||||||
|
name = data["id"]
|
||||||
|
if "severity" in data["info"]:
|
||||||
|
severity = data["info"]["severity"]
|
||||||
|
else:
|
||||||
|
severity = "unkown"
|
||||||
|
hash = calculate_md5_from_content(file_content)
|
||||||
|
if hash in hash_list:
|
||||||
|
repeat_num += 1
|
||||||
|
continue
|
||||||
|
if severity in severity_dic:
|
||||||
|
severity = severity_dic[severity]
|
||||||
|
else:
|
||||||
|
severity = 1
|
||||||
|
formatted_time = get_now_time()
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"content": file_content,
|
||||||
|
"hash": hash,
|
||||||
|
"level": severity,
|
||||||
|
"time": formatted_time
|
||||||
|
}
|
||||||
|
poc_data_list.append(data)
|
||||||
except:
|
except:
|
||||||
pass
|
logger.info(f"POC导入 读取文件失败: {yaml_file}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
error_num += 1
|
||||||
|
continue
|
||||||
|
if len(poc_data_list) != 0:
|
||||||
|
result = await db.PocList.insert_many(poc_data_list)
|
||||||
|
if result.inserted_ids:
|
||||||
|
success_num += len(result.inserted_ids)
|
||||||
|
await refresh_config('all', 'poc')
|
||||||
|
logger.info(f"POC更新成功: {success_num} 重复:{repeat_num} 失败: {error_num}")
|
||||||
|
try:
|
||||||
|
os.remove(zip_file_path)
|
||||||
|
shutil.rmtree(extract_path)
|
||||||
|
except:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error("删除POC文件出错")
|
||||||
|
logger.info("POC导入结束")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/poc/data/all")
|
@router.get("/poc/data/all")
|
||||||
|
@ -194,6 +261,9 @@ async def add_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = D
|
||||||
hash_value = calculate_md5_from_content(content)
|
hash_value = calculate_md5_from_content(content)
|
||||||
level = request_data.get("level")
|
level = request_data.get("level")
|
||||||
formatted_time = get_now_time()
|
formatted_time = get_now_time()
|
||||||
|
doc = await db.PocList.find_one({"hash": hash_value}, {"_id": 1})
|
||||||
|
if doc:
|
||||||
|
return {"message": "POC已存在", "code": 500}
|
||||||
# Insert data into the database
|
# Insert data into the database
|
||||||
result = await db.PocList.insert_one({
|
result = await db.PocList.insert_one({
|
||||||
"name": name,
|
"name": name,
|
||||||
|
|
|
@ -18,7 +18,7 @@ from core.db import get_mongo_db
|
||||||
|
|
||||||
def calculate_md5_from_content(content):
|
def calculate_md5_from_content(content):
|
||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
md5.update(content.encode("utf-8")) # 将内容编码为 utf-8 后更新 MD5
|
md5.update(content.encode("utf-8"))
|
||||||
return md5.hexdigest()
|
return md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
|
2
main.py
2
main.py
|
@ -261,4 +261,4 @@ def banner():
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
banner()
|
banner()
|
||||||
file_path = os.path.join(os.getcwd(), "file")
|
file_path = os.path.join(os.getcwd(), "file")
|
||||||
uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=True, reload_excludes=[file_path])
|
uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=False, reload_excludes=[file_path])
|
||||||
|
|
Loading…
Reference in New Issue