fix bug, add poc import

This commit is contained in:
Autumn.home 2024-07-17 21:39:09 +08:00
parent 689debfe87
commit af1290df4e
6 changed files with 70 additions and 13 deletions

View File

@ -44,7 +44,7 @@ async def asset_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Dep
try:
if len(APP) == 0:
collection = db["FingerprintRules"]
cursor = await collection.find({}, {"_id": 1, "name": 1})
cursor = collection.find({}, {"_id": 1, "name": 1})
async for document in cursor:
document['id'] = str(document['_id'])
del document['_id']
@ -114,7 +114,10 @@ async def asset_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Dep
tmp['products'] = tmp['products'] + technologies
if r['webfinger'] is not None:
for w in r['webfinger']:
tmp['products'].append(APP[w])
if w in APP:
tmp['products'].append(APP[w])
else:
tmp['products'].append(w)
result_list.append(tmp)
return {
"code": 200,

View File

@ -133,7 +133,7 @@ async def export_data(request_data: dict, db=Depends(get_mongo_db), _: dict = De
async def fetch_data(db, collection, query, quantity, project_list):
# 构造替换字段值的pipeline
branches = []
branches = [{"case": {"$eq": ["$project", ""]}, "then": ""}]
for new_value, original_value in project_list.items():
branches.append({"case": {"$eq": ["$project", original_value]}, "then": new_value})

View File

@ -2,8 +2,11 @@
# @name: poc_manage
# @auth: rainy-autumn@outlook.com
# @version:
import os
import yaml
from bson import ObjectId
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, File, UploadFile
from motor.motor_asyncio import AsyncIOMotorCursor
from api.users import verify_token
from core.db import get_mongo_db
@ -11,6 +14,7 @@ from pymongo import ASCENDING, DESCENDING
from loguru import logger
from core.redis_handler import refresh_config
from core.util import *
import zipfile
router = APIRouter()
@ -41,6 +45,49 @@ async def poc_data(request_data: dict, db=Depends(get_mongo_db), _: dict = Depen
return {"message": "error", "code": 500}
def is_safe_path(base_path, target_path):
# 计算规范化路径
abs_base_path = os.path.abspath(base_path)
abs_target_path = os.path.abspath(target_path)
return abs_target_path.startswith(abs_base_path)
@router.post("/poc/data/import")
async def poc_import(file: UploadFile = File(...), db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
if not file.filename.endswith('.zip'):
return {"message": "not zip", "code": 500}
file_name = generate_random_string(5)
relative_path = f'file\\{file_name}.zip'
zip_file_path = os.path.join(os.getcwd(), relative_path)
with open(zip_file_path, "wb") as f:
f.write(await file.read())
yaml_files = []
unzip_path = f'file\\{file_name}'
file_path = os.path.join(os.getcwd(), unzip_path)
extract_path = file_path
os.makedirs(extract_path, exist_ok=True)
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
for member in zip_ref.namelist():
member_path = os.path.join(extract_path, member)
if not is_safe_path(extract_path, member_path):
return {"message": "Unsafe file path detected in ZIP file", "code": 500}
zip_ref.extractall(extract_path)
for root, dirs, files in os.walk(extract_path):
for filename in files:
if filename.endswith('.yaml') or filename.endswith('.yml'):
file_path = os.path.join(root, filename)
yaml_files.append(file_path)
for yaml_file in yaml_files:
with open(yaml_file, 'r') as stream:
try:
data = yaml.safe_load(stream)
print(data["id"])
except:
pass
@router.get("/poc/data/all")
async def poc_data(db=Depends(get_mongo_db), _: dict = Depends(verify_token)):
try:

View File

@ -123,12 +123,13 @@ async def update_project_count():
async def update_count(id):
query = {"project": {"$eq": id}}
total_count = await db.asset.count_documents(query)
update_document = {
"$set": {
"AssetCount": total_count
if total_count != 0:
update_document = {
"$set": {
"AssetCount": total_count
}
}
}
await db.project.update_one({"_id": ObjectId(id)}, update_document)
await db.project.update_one({"_id": ObjectId(id)}, update_document)
fetch_tasks = [update_count(r['id']) for r in results]

View File

@ -12,7 +12,8 @@ from loguru import logger
async def get_mongo_db():
client = AsyncIOMotorClient(f"mongodb://{DATABASE_USER}:{DATABASE_PASSWORD}@{MONGODB_IP}:{str(MONGODB_PORT)}",
client = AsyncIOMotorClient(f"mongodb://{DATABASE_USER}:{quote_plus(DATABASE_PASSWORD)}@{MONGODB_IP}:{str(MONGODB_PORT)}",
serverSelectionTimeoutMS=10000, unicode_decode_error_handler='ignore')
db = client[DATABASE_NAME]
try:

11
main.py
View File

@ -87,11 +87,15 @@ async def update():
logger.error("No DomainDic content to upload.")
# 更新敏感信息
await db.SensitiveRule.delete_many({})
sensitive_data = get_sensitive()
collection = db["SensitiveRule"]
if sensitive_data:
await collection.insert_many(sensitive_data)
for s in sensitive_data:
await collection.update_one(
{"name": s['name']},
{"$set": s},
upsert=True
)
await db.config.update_one({"name": "version"}, {"$set": {"update": True, "version": float(VERSION)}})
@ -256,4 +260,5 @@ def banner():
if __name__ == "__main__":
banner()
uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=True)
file_path = os.path.join(os.getcwd(), "file")
uvicorn.run("main:app", host="0.0.0.0", port=8082, reload=True, reload_excludes=[file_path])