ScopeSentry/api/poc.py

311 lines
11 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding:utf-8 -*-  
# @name: poc_manage
# @auth: rainy-autumn@outlook.com
# @version:
import os
import shutil
import traceback
import yaml
from bson import ObjectId
from fastapi import APIRouter, Depends, File, UploadFile
from motor.motor_asyncio import AsyncIOMotorCursor
from starlette.background import BackgroundTasks
from api.users import verify_token
from core.db import get_mongo_db
from pymongo import ASCENDING, DESCENDING
from loguru import logger
from core.redis_handler import refresh_config
from core.util import *
import zipfile
# Router that every POC-management endpoint in this module registers on
# (presumably included into the FastAPI app elsewhere — confirm).
router = APIRouter()
@router.post("/poc/data")
async def poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Return one page of POCs whose name contains the search text.

    Request body keys (all optional):
      - ``search``: text matched case-insensitively as a literal substring of ``name``
      - ``pageIndex``: 1-based page number (default 1; values below 1 are treated as 1)
      - ``pageSize``: items per page (default 10)

    Returns ``{"code": 200, "data": {"list": [...], "total": n}}``. Each list item
    carries ``id`` (the stringified ``_id``), ``name``, ``level`` and ``time``,
    ordered newest first. Returns ``{"message": "error", "code": 500}`` on any failure.
    """
    import re

    try:
        # Escape the user text so characters such as "(" or "*" are matched
        # literally instead of being parsed as (possibly invalid) regex syntax.
        search_text = request_data.get("search") or ""
        search_query = re.escape(str(search_text))
        # A page index below 1 would produce a negative skip(), which pymongo rejects.
        page_index = max(int(request_data.get("pageIndex", 1)), 1)
        page_size = request_data.get("pageSize", 10)
        query = {"name": {"$regex": search_query, "$options": "i"}}
        # Total number of matches, independent of pagination.
        total_count = await db.PocList.count_documents(query)
        projection = {"_id": 0, "id": {"$toString": "$_id"}, "name": 1, "level": 1, "time": 1}
        cursor: AsyncIOMotorCursor = (
            db.PocList.find(query, projection)
            .sort([("time", DESCENDING)])
            .skip((page_index - 1) * page_size)
            .limit(page_size)
        )
        result = await cursor.to_list(length=None)
        return {
            "code": 200,
            "data": {
                'list': result,
                'total': total_count
            }
        }
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
def is_safe_path(base_path, target_path):
    """Return True if ``target_path`` is ``base_path`` itself or lies inside it.

    Both paths are first normalised with ``os.path.abspath``, which collapses
    ``..`` segments. Containment is then decided by comparing whole path
    components with ``os.path.commonpath``. A plain string-prefix test would
    wrongly accept a sibling such as ``/data/x_evil`` for the base ``/data/x``.
    """
    # Compute normalised absolute paths.
    abs_base_path = os.path.abspath(base_path)
    abs_target_path = os.path.abspath(target_path)
    try:
        return os.path.commonpath([abs_base_path, abs_target_path]) == abs_base_path
    except ValueError:
        # commonpath raises ValueError when the paths share no common root,
        # e.g. they are on different drives on Windows: certainly not inside.
        return False
@router.post("/poc/data/import")
async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token), background_tasks: BackgroundTasks = BackgroundTasks()):
    """Accept a ZIP archive of POC templates and schedule its import.

    Only the filename extension is checked here, case-insensitively.
    Extraction and the database insert are done by ``import_poc_handle``,
    which is queued as a background task.

    NOTE(review): the UploadFile object itself is handed to the background task,
    which reads it after this handler returns. Confirm the installed FastAPI
    version does not close uploads before background tasks run. If it does,
    read the bytes here and pass those instead.
    """
    # filename can be None for some multipart clients; treat that as "not zip"
    # instead of failing with AttributeError.
    filename = file.filename or ""
    if not filename.lower().endswith('.zip'):
        return {"message": "not zip", "code": 500}
    background_tasks.add_task(import_poc_handle, file)
    return {"message": "正在导入中", "code": 200}
async def import_poc_handle(file):
    """Import POC templates from an uploaded ZIP archive (runs as a background task).

    Steps:
      1. Write the archive bytes (``await file.read()``) to ``<cwd>/file/<random>.zip``.
      2. Extract it to ``<cwd>/file/<random>/``. The whole archive is rejected
         if any member path would escape that directory.
      3. Parse every ``*.yaml`` file found and insert the new templates into
         ``db.PocList``.

    The temporary zip and extraction directory are removed in every case,
    including errors and the unsafe-archive rejection.

    :param file: the uploaded archive; only its awaitable ``read()`` is used.
    """
    logger.info("POC导入开始")
    async for db in get_mongo_db():
        file_name = generate_random_string(5)
        # Use os.path.join instead of the former hard-coded "file\\..." so a real
        # "file" sub-directory is used on POSIX as well as on Windows.
        work_dir = os.path.join(os.getcwd(), "file")
        os.makedirs(work_dir, exist_ok=True)
        zip_file_path = os.path.join(work_dir, f"{file_name}.zip")
        extract_path = os.path.join(work_dir, file_name)
        try:
            with open(zip_file_path, "wb") as f:
                f.write(await file.read())
            os.makedirs(extract_path, exist_ok=True)
            if not _extract_zip_safely(zip_file_path, extract_path):
                logger.error("Unsafe file path detected in ZIP file")
                return
            await _store_pocs(db, _collect_yaml_files(extract_path))
        finally:
            # Runs on success, on the early return above, and on exceptions.
            # Previously the latter two leaked both temporary paths.
            _remove_import_files(zip_file_path, extract_path)
        logger.info("POC导入结束")


# Numeric level stored for each template severity; any other value maps to 1.
_SEVERITY_LEVELS = {
    "critical": 6,
    "high": 5,
    "medium": 4,
    "low": 3,
    "info": 2,
    "unknown": 1,
}


def _extract_zip_safely(zip_file_path, extract_path):
    """Extract the archive into ``extract_path``.

    Returns False, without extracting anything, if any member name would
    resolve outside ``extract_path`` (checked with ``is_safe_path``).
    Otherwise extracts everything and returns True.
    """
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        for member in zip_ref.namelist():
            if not is_safe_path(extract_path, os.path.join(extract_path, member)):
                return False
        zip_ref.extractall(extract_path)
    return True


def _collect_yaml_files(root_dir):
    """Return the paths of all files under ``root_dir`` (recursively) whose name ends in ``.yaml``."""
    return [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(root_dir)
        for filename in files
        if filename.endswith('.yaml')
    ]


def _remove_import_files(zip_file_path, extract_path):
    """Best-effort removal of the temporary archive and its extraction directory; errors are only logged."""
    try:
        if os.path.exists(zip_file_path):
            os.remove(zip_file_path)
        if os.path.isdir(extract_path):
            shutil.rmtree(extract_path)
    except Exception:
        logger.error(traceback.format_exc())
        logger.error("删除POC文件出错")


async def _store_pocs(db, yaml_files):
    """Parse ``yaml_files`` and insert the templates not already in ``db.PocList``.

    Templates are de-duplicated by ``calculate_md5_from_content`` of the raw file
    text. The hash is checked both against the stored ``hash`` values and against
    earlier files in the same batch. A file that cannot be read or parsed, or that
    lacks ``id``/``info``, is logged and counted as a failure. The POC config is
    refreshed via ``refresh_config`` when anything was inserted.
    """
    hash_doc = await db.PocList.find({}, {"hash": 1, "_id": 0}).to_list(length=None)
    # A set gives O(1) membership. Each new hash is added as it is accepted, so a
    # template that appears twice inside the same archive is inserted only once.
    known_hashes = {item.get("hash") for item in hash_doc}
    success_num = 0
    error_num = 0
    repeat_num = 0
    logger.info(f"{len(yaml_files)}个POC")
    poc_data_list = []
    for yaml_file in yaml_files:
        try:
            # open() is inside the try so an unreadable file counts as a failure
            # instead of aborting the whole import.
            with open(yaml_file, 'r', encoding='utf-8') as stream:
                file_content = stream.read()
            data = yaml.safe_load(file_content)
            name = data["id"]
            info = data["info"]
            severity = info["severity"] if "severity" in info else "unknown"
            content_hash = calculate_md5_from_content(file_content)
            if content_hash in known_hashes:
                repeat_num += 1
                continue
            poc_doc = {
                "name": name,
                "content": file_content,
                "hash": content_hash,
                "level": _SEVERITY_LEVELS.get(severity, 1),
                "time": get_now_time()
            }
        except Exception:
            logger.info(f"POC导入 读取文件失败: {yaml_file}")
            logger.error(traceback.format_exc())
            error_num += 1
            continue
        known_hashes.add(content_hash)
        poc_data_list.append(poc_doc)
    if poc_data_list:
        result = await db.PocList.insert_many(poc_data_list)
        if result.inserted_ids:
            success_num += len(result.inserted_ids)
            await refresh_config('all', 'poc')
    logger.info(f"POC更新成功: {success_num} 重复:{repeat_num} 失败: {error_num}")
@router.get("/poc/data/all")
async def poc_data(db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Return every POC's ``id`` (stringified ``_id``), ``name`` and ``time``, newest first.

    Returns ``{"code": 200, "data": {"list": [...]}}``, or
    ``{"message": "error", "code": 500}`` on any failure.

    NOTE(review): this function reuses the name ``poc_data`` of the
    POST ``/poc/data`` handler defined earlier in the module. Both routes are
    registered by their decorators, but the module-level name ``poc_data``
    refers only to this one.
    """
    try:
        # Projection values only choose which fields are returned; ordering is
        # done by sort(). The former "time": -1 also meant "include", because
        # MongoDB treats any non-zero projection value as inclusion, but it read
        # like a descending sort key.
        projection = {"_id": 0, "id": {"$toString": "$_id"}, "name": 1, "time": 1}
        cursor: AsyncIOMotorCursor = db.PocList.find({}, projection).sort([("time", DESCENDING)])
        result = await cursor.to_list(length=None)
        return {
            "code": 200,
            "data": {
                'list': result
            }
        }
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
@router.post("/poc/content")
async def poc_content(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Return the stored template text of one POC.

    Request body: ``{"id": "<ObjectId hex string>"}``.

    Returns one of:
      - code 200 with ``{"data": {"content": ...}}`` on success;
      - code 400 when the id is missing or is not a valid ObjectId;
      - code 404 when no document has that id;
      - code 500 on any other failure.
    """
    try:
        poc_id = request_data.get("id")
        if not poc_id:
            return {"message": "ID is missing in the request data", "code": 400}
        # A malformed id is a client error; without this check ObjectId() raises
        # InvalidId and the caller only sees a generic 500.
        if not ObjectId.is_valid(poc_id):
            return {"message": "Invalid ID in the request data", "code": 400}
        doc = await db.PocList.find_one({"_id": ObjectId(poc_id)})
        if not doc:
            return {"message": "Content not found for the provided ID", "code": 404}
        return {"code": 200, "data": {"content": doc.get("content", "")}}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
@router.post("/poc/update")
async def update_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Update the ``name``, ``content`` and/or ``level`` of an existing POC.

    Request body: ``id`` (required ObjectId hex string) plus any of ``name``,
    ``content`` and ``level``. Only fields present with a non-None value are
    written. Whenever ``content`` is written, its ``hash`` (from
    ``calculate_md5_from_content``) is recomputed alongside it. On success the
    POC config is refreshed via ``refresh_config('all', 'poc')``.

    Returns one of:
      - code 200 on success;
      - code 400 when the id is missing or invalid, or nothing is to be updated;
      - code 404 when no POC has that id;
      - code 500 on any other failure.
    """
    try:
        poc_id = request_data.get("id")
        if not poc_id:
            return {"message": "ID is missing in the request data", "code": 400}
        if not ObjectId.is_valid(poc_id):
            return {"message": "Invalid ID in the request data", "code": 400}

        # Build $set from the supplied fields only, so omitting a field no longer
        # overwrites the stored value with null.
        update_fields = {
            key: request_data[key]
            for key in ("name", "content", "level")
            if request_data.get(key) is not None
        }
        if "content" in update_fields:
            update_fields["hash"] = calculate_md5_from_content(update_fields["content"])
        if not update_fields:
            return {"message": "Data to update is missing in the request", "code": 400}

        result = await db.PocList.update_one({"_id": ObjectId(poc_id)}, {"$set": update_fields})
        # An UpdateResult object is always truthy. matched_count tells whether a
        # document with this id actually existed.
        if result.matched_count == 0:
            return {"message": "Failed to update data", "code": 404}
        await refresh_config('all', 'poc')
        return {"message": "Data updated successfully", "code": 200}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
@router.post("/poc/add")
async def add_poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Create a new POC from ``name``, ``content`` and ``level`` in the request body.

    ``calculate_md5_from_content(content)`` is stored as ``hash`` and used for
    de-duplication: if a POC with the same hash exists, nothing is inserted.
    On success the POC config is refreshed via ``refresh_config('all', 'poc')``.

    Returns one of:
      - code 200 on success;
      - code 400 when the body is empty or ``content`` is missing / not a string;
      - ``{"message": "POC已存在", "code": 500}`` for a duplicate;
      - code 400 ("Failed to add data") if no id comes back from the insert;
      - code 500 on any other failure.
    """
    try:
        if not request_data:
            return {"message": "Data to add is missing in the request", "code": 400}
        content = request_data.get("content")
        # Reject a missing or non-text template up front instead of relying on
        # how the hashing helper happens to react to None.
        if not isinstance(content, str):
            return {"message": "Data to add is missing in the request", "code": 400}
        name = request_data.get("name")
        level = request_data.get("level")
        hash_value = calculate_md5_from_content(content)
        formatted_time = get_now_time()
        doc = await db.PocList.find_one({"hash": hash_value}, {"_id": 1})
        if doc:
            return {"message": "POC已存在", "code": 500}
        result = await db.PocList.insert_one({
            "name": name,
            "content": content,
            "hash": hash_value,
            "level": level,
            "time": formatted_time
        })
        if result.inserted_id:
            await refresh_config('all', 'poc')
            return {"message": "Data added successfully", "code": 200}
        return {"message": "Failed to add data", "code": 400}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}
@router.post("/poc/delete")
async def delete_poc_rules(request_data: dict, db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
    """Delete the POCs whose ids are listed in ``request_data["ids"]``.

    Returns one of:
      - code 200 when at least one document was deleted (the POC config is then
        refreshed via ``refresh_config('all', 'poc')``);
      - code 404 when no id matched;
      - code 500 on any failure, including a malformed id.
    """
    try:
        poc_ids = request_data.get("ids", [])
        # Convert the provided POC ids to ObjectId; an invalid id raises and
        # ends in the generic 500 below.
        obj_ids = [ObjectId(poc_id) for poc_id in poc_ids]
        result = await db.PocList.delete_many({"_id": {"$in": obj_ids}})
        if result.deleted_count > 0:
            # Refresh like the add/update/import paths do, so the change is
            # propagated instead of waiting for some other refresh.
            await refresh_config('all', 'poc')
            return {"code": 200, "message": "Poc deleted successfully"}
        return {"code": 404, "message": "Poc not found"}
    except Exception as e:
        logger.error(str(e))
        return {"message": "error", "code": 500}