2024-06-05 13:39:34 +00:00
|
|
|
|
# -*- coding:utf-8 -*-
|
|
|
|
|
# @name: poc_manage
|
|
|
|
|
# @auth: rainy-autumn@outlook.com
|
|
|
|
|
# @version:
|
2024-07-17 13:39:09 +00:00
|
|
|
|
import os
|
2024-07-18 15:50:50 +00:00
|
|
|
|
import shutil
|
|
|
|
|
import traceback
|
2024-07-17 13:39:09 +00:00
|
|
|
|
|
|
|
|
|
import yaml
|
2024-06-05 13:39:34 +00:00
|
|
|
|
from bson import ObjectId
|
2024-07-17 13:39:09 +00:00
|
|
|
|
from fastapi import APIRouter, Depends, File, UploadFile
|
2024-06-05 13:39:34 +00:00
|
|
|
|
from motor.motor_asyncio import AsyncIOMotorCursor
|
2024-07-18 15:50:50 +00:00
|
|
|
|
from starlette.background import BackgroundTasks
|
|
|
|
|
|
2024-06-05 13:39:34 +00:00
|
|
|
|
from api.users import verify_token
|
|
|
|
|
from core.db import get_mongo_db
|
|
|
|
|
from pymongo import ASCENDING, DESCENDING
|
|
|
|
|
from loguru import logger
|
|
|
|
|
from core.redis_handler import refresh_config
|
|
|
|
|
from core.util import *
|
2024-07-17 13:39:09 +00:00
|
|
|
|
import zipfile
|
2024-06-05 13:39:34 +00:00
|
|
|
|
# Router collecting the /poc/... endpoints defined in this module.
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/poc/data")
async def poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Return one page of POCs whose name contains ``search`` (case-insensitive).

    Request body keys:
        search: text to look for in the POC name (default ``""``). It is
            matched literally: regex metacharacters are escaped, so user input
            can neither break the query nor inject an expensive pattern.
        pageIndex: 1-based page number (default 1; values below 1 become 1).
        pageSize: items per page (default 10; values below 1 become 1).

    Returns ``{"code": 200, "data": {"list": [...], "total": <int>}}`` with
    items (``id``, ``name``, ``level``, ``time``) sorted by ``time`` descending,
    or ``{"message": "error", "code": 500}`` on failure.
    """
    import re

    try:
        search_query = str(request_data.get("search") or "")
        # Coerce and clamp pagination input: strings, 0 or negative values would
        # otherwise make skip() raise and surface as a generic 500.
        page_index = max(int(request_data.get("pageIndex") or 1), 1)
        page_size = max(int(request_data.get("pageSize") or 10), 1)

        query = {"name": {"$regex": re.escape(search_query), "$options": "i"}}

        # Total number of documents matching the search criteria.
        total_count = await db.PocList.count_documents(query)

        # One page of results, newest first; _id is exposed as a string "id".
        projection = {"_id": 0, "id": {"$toString": "$_id"}, "name": 1, "level": 1, "time": 1}
        cursor: AsyncIOMotorCursor = (
            db.PocList.find(query, projection)
            .sort([("time", DESCENDING)])
            .skip((page_index - 1) * page_size)
            .limit(page_size)
        )
        result = await cursor.to_list(length=None)

        return {
            "code": 200,
            "data": {
                'list': result,
                'total': total_count
            }
        }
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
|
|
|
|
|
|
|
|
|
|
|
2024-07-17 13:39:09 +00:00
|
|
|
|
def is_safe_path(base_path, target_path):
    """Return True if *target_path* is *base_path* itself or lies inside it.

    Both paths are made absolute and normalised (``..`` segments collapsed)
    before comparing. The comparison is done component-wise with
    ``os.path.commonpath``. A plain string-prefix test would wrongly accept
    sibling paths such as ``/data/extract2/x`` for base ``/data/extract``.

    Args:
        base_path: directory that must contain the target.
        target_path: path to check, e.g. a ZIP member joined onto *base_path*.

    Returns:
        bool: False when the target escapes *base_path*, or when the two paths
        cannot be compared (e.g. different drives on Windows).
    """
    # Normalise both paths before comparing.
    abs_base_path = os.path.abspath(base_path)
    abs_target_path = os.path.abspath(target_path)
    try:
        return os.path.commonpath([abs_base_path, abs_target_path]) == abs_base_path
    except ValueError:
        # commonpath raises when the paths share no root (different drives).
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/poc/data/import")
async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token), background_tasks: BackgroundTasks = BackgroundTasks()):
    """Accept a ZIP archive of POC YAML templates and schedule its import.

    Only the file name is validated here; the import itself runs in
    ``import_poc_handle`` as a background task. The upload is read into memory
    while the request is still being handled, and only the raw bytes are passed
    on. The ``UploadFile`` may already be closed by the time the background
    task runs.

    Returns ``{"message": "not zip", "code": 500}`` for non-``.zip`` names,
    otherwise ``{"message": "正在导入中", "code": 200}`` right away.
    """
    # filename may be None for some clients; accept ".ZIP" as well as ".zip".
    filename = file.filename or ""
    if not filename.lower().endswith('.zip'):
        return {"message": "not zip", "code": 500}

    zip_bytes = await file.read()
    background_tasks.add_task(import_poc_handle, zip_bytes)
    return {"message": "正在导入中", "code": 200}


def _extract_zip_safely(zip_file_path, extract_path):
    """Extract *zip_file_path* into *extract_path*, refusing zip-slip archives.

    Returns False, without extracting anything, if any member would land
    outside *extract_path*; True once everything has been extracted.
    Raises ``zipfile.BadZipFile`` if the file is not a valid ZIP archive.
    """
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        for member in zip_ref.namelist():
            if not is_safe_path(extract_path, os.path.join(extract_path, member)):
                return False
        zip_ref.extractall(extract_path)
    return True


def _collect_yaml_files(root_dir):
    """Return the paths of all ``*.yaml`` files under *root_dir*, recursively."""
    return [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(root_dir)
        for filename in files
        if filename.endswith('.yaml')
    ]


def _remove_import_artifacts(zip_file_path, extract_path):
    """Best-effort removal of the temporary archive and its extraction directory."""
    try:
        if os.path.exists(zip_file_path):
            os.remove(zip_file_path)
        if os.path.isdir(extract_path):
            shutil.rmtree(extract_path)
    except OSError:
        logger.error(traceback.format_exc())
        logger.error("删除POC文件出错")


async def import_poc_handle(file):
    """Import every ``*.yaml`` POC template contained in a ZIP archive.

    Args:
        file: the archive content as ``bytes``. For backward compatibility,
            any object with an awaitable ``read()`` (e.g. an ``UploadFile``)
            is also accepted.

    Each template's ``id`` becomes the POC name, and ``info.severity`` is
    mapped to a numeric level (unknown or missing severities map to 1).
    Templates whose MD5 content hash already exists in ``PocList``, or that
    appeared earlier in the same archive, are counted as duplicates and
    skipped. Temporary files are removed whether or not the import succeeds.
    """
    logger.info("POC导入开始")
    if isinstance(file, (bytes, bytearray)):
        zip_bytes = bytes(file)
    else:
        zip_bytes = await file.read()

    severity_dic = {
        "critical": 6,
        "high": 5,
        "medium": 4,
        "low": 3,
        "info": 2,
        "unkown": 1
    }

    async for db in get_mongo_db():
        # Use os.path.join instead of a hard-coded Windows separator so the
        # temp files land in ./file/ on every platform.
        work_dir = os.path.join(os.getcwd(), 'file')
        os.makedirs(work_dir, exist_ok=True)
        file_name = generate_random_string(5)
        zip_file_path = os.path.join(work_dir, f'{file_name}.zip')
        extract_path = os.path.join(work_dir, file_name)
        try:
            with open(zip_file_path, "wb") as f:
                f.write(zip_bytes)
            os.makedirs(extract_path, exist_ok=True)

            try:
                if not _extract_zip_safely(zip_file_path, extract_path):
                    logger.error("Unsafe file path detected in ZIP file")
                    return
            except zipfile.BadZipFile:
                logger.error(traceback.format_exc())
                return

            yaml_files = _collect_yaml_files(extract_path)

            # A set gives O(1) duplicate checks; it also grows as templates are
            # accepted, so duplicates inside the same archive are caught too.
            hash_doc = await db.PocList.find({}, {"hash": 1, "_id": 0}).to_list(length=None)
            known_hashes = {item.get("hash") for item in hash_doc}

            success_num = 0
            error_num = 0
            repeat_num = 0
            logger.info(f"共{len(yaml_files)}个POC")
            poc_data_list = []
            for yaml_file in yaml_files:
                try:
                    with open(yaml_file, 'r', encoding='utf-8') as stream:
                        file_content = stream.read()
                    data = yaml.safe_load(file_content)
                    name = data["id"]
                    info = data["info"]
                    severity = info["severity"] if "severity" in info else "unkown"
                    content_hash = calculate_md5_from_content(file_content)
                    if content_hash in known_hashes:
                        repeat_num += 1
                        continue
                    level = severity_dic[severity] if severity in severity_dic else 1
                    known_hashes.add(content_hash)
                    poc_data_list.append({
                        "name": name,
                        "content": file_content,
                        "hash": content_hash,
                        "level": level,
                        "time": get_now_time()
                    })
                except Exception:
                    logger.info(f"POC导入 读取文件失败: {yaml_file}")
                    logger.error(traceback.format_exc())
                    error_num += 1

            if poc_data_list:
                try:
                    result = await db.PocList.insert_many(poc_data_list)
                    success_num = len(result.inserted_ids)
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("POC导入 写入数据库失败")
                await refresh_config('all', 'poc')
            logger.info(f"POC更新成功: {success_num} 重复:{repeat_num} 失败: {error_num}")
        finally:
            # Always clean up, including on the unsafe-archive early return
            # and on unexpected errors.
            _remove_import_artifacts(zip_file_path, extract_path)
            logger.info("POC导入结束")
|
2024-07-17 13:39:09 +00:00
|
|
|
|
|
|
|
|
|
|
2024-06-05 13:39:34 +00:00
|
|
|
|
@router.get("/poc/data/all")
async def poc_data(db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """List every POC (string ``id``, ``name``, ``time``), newest first, unpaginated.

    Returns ``{"code": 200, "data": {"list": [...]}}``, or
    ``{"message": "error", "code": 500}`` if the query fails.

    NOTE(review): this function reuses the name of the paginated POST
    ``/poc/data`` handler and rebinds the module-level ``poc_data``. Both routes
    still register because the decorator records each one when it is applied.
    """
    # "time": -1 is kept as in the original; as a non-zero value it appears to
    # act as an inclusion flag in MongoDB projections — TODO confirm.
    fields = {"id": {"$toString": "$_id"}, "name": 1, "time": -1, "_id": 0}
    try:
        newest_first = db.PocList.find({}, fields).sort([("time", DESCENDING)])
        everything = await newest_first.to_list(length=None)
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
    return {"code": 200, "data": {'list': everything}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/poc/content")
async def poc_content(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Return the stored YAML content of the POC with the given ``id``.

    Request body keys:
        id: the POC's ObjectId as a 24-character hex string.

    Returns ``{"code": 200, "data": {"content": <str>}}``. Error responses:
    code 400 when the id is missing or malformed, 404 when no POC matches,
    and 500 for any other failure.
    """
    try:
        poc_id = request_data.get("id")
        if not poc_id:
            return {"message": "ID is missing in the request data", "code": 400}
        # Reject malformed ids up front; ObjectId() would otherwise raise and
        # the client would only see a generic 500.
        if not ObjectId.is_valid(poc_id):
            return {"message": "Invalid ID", "code": 400}

        # Only the content field is needed.
        doc = await db.PocList.find_one({"_id": ObjectId(poc_id)}, {"content": 1})
        if not doc:
            return {"message": "Content not found for the provided ID", "code": 404}

        return {"code": 200, "data": {"content": doc.get("content", "")}}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/poc/update")
async def update_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Update the name, content and/or level of an existing POC.

    Request body keys:
        id: the POC's ObjectId as a hex string (required).
        name, content, level: new values. Keys that are missing or ``None``
            leave the stored field untouched. When ``content`` changes, its MD5
            ``hash`` is recomputed.

    On success the POC config is refreshed via ``refresh_config('all', 'poc')``.
    Returns code 200 on success. Errors: 400 for a missing/malformed id or no
    fields to update, 404 when no POC matches the id, 500 on unexpected failure.
    """
    try:
        poc_id = request_data.get("id")
        if not poc_id:
            return {"message": "ID is missing in the request data", "code": 400}
        if not ObjectId.is_valid(poc_id):
            return {"message": "Invalid ID", "code": 400}

        # Only $set the fields actually supplied. Writing None for omitted keys
        # would silently wipe stored data.
        fields = {}
        name = request_data.get("name")
        if name is not None:
            fields["name"] = name
        content = request_data.get("content")
        if content is not None:
            fields["content"] = content
            fields["hash"] = calculate_md5_from_content(content)
        level = request_data.get("level")
        if level is not None:
            fields["level"] = level

        if not fields:
            return {"message": "Data to update is missing in the request", "code": 400}

        result = await db.PocList.update_one({"_id": ObjectId(poc_id)}, {"$set": fields})
        # An UpdateResult object is always truthy; matched_count tells whether
        # the document exists.
        if result.matched_count == 0:
            return {"message": "Failed to update data", "code": 404}

        await refresh_config('all', 'poc')
        return {"message": "Data updated successfully", "code": 200}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/poc/add")
async def add_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Create a new POC from ``name``, ``content`` and ``level``.

    The MD5 of ``content`` (via ``calculate_md5_from_content``) is stored as
    ``hash``. A POC whose hash already exists is rejected with
    ``{"message": "POC已存在", "code": 500}``. On success the POC config is
    refreshed via ``refresh_config('all', 'poc')`` and code 200 is returned.
    Code 400 means the body or the content is missing/blank, or the insert was
    not acknowledged; code 500 means an unexpected error.
    """
    try:
        if not request_data:
            return {"message": "Data to add is missing in the request", "code": 400}

        name = request_data.get("name")
        content = request_data.get("content")
        # Without content there is nothing to hash or run; reject it here
        # instead of failing inside the hash helper or storing an empty POC.
        if not isinstance(content, str) or not content.strip():
            return {"message": "POC content is missing in the request", "code": 400}
        hash_value = calculate_md5_from_content(content)
        level = request_data.get("level")
        formatted_time = get_now_time()

        doc = await db.PocList.find_one({"hash": hash_value}, {"_id": 1})
        if doc:
            return {"message": "POC已存在", "code": 500}

        result = await db.PocList.insert_one({
            "name": name,
            "content": content,
            "hash": hash_value,
            "level": level,
            "time": formatted_time
        })

        if result.inserted_id:
            await refresh_config('all', 'poc')
            return {"message": "Data added successfully", "code": 200}
        return {"message": "Failed to add data", "code": 400}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/poc/delete")
async def delete_poc_rules(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Delete the POCs whose ObjectId hex strings are listed in ``ids``.

    Nothing is deleted if any id is malformed (code 400). Returns code 200 when
    at least one POC was removed, after refreshing the POC config via
    ``refresh_config('all', 'poc')`` as add/update do. Returns 404 when nothing
    matched and 500 on unexpected errors.
    """
    try:
        poc_ids = request_data.get("ids", [])

        # Validate every id first so a single bad value neither raises a
        # generic 500 nor leads to a partial delete.
        if not all(ObjectId.is_valid(poc_id) for poc_id in poc_ids):
            return {"code": 400, "message": "Invalid ID"}
        obj_ids = [ObjectId(poc_id) for poc_id in poc_ids]

        # Delete the PocList documents with the given ids.
        result = await db.PocList.delete_many({"_id": {"$in": obj_ids}})

        if result.deleted_count > 0:
            # Keep the pushed POC config in sync, as add/update do.
            await refresh_config('all', 'poc')
            return {"code": 200, "message": "Poc deleted successfully"}
        return {"code": 404, "message": "Poc not found"}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
|