diff --git a/core/db.py b/core/db.py index fa643c6..1c95cfa 100644 --- a/core/db.py +++ b/core/db.py @@ -40,17 +40,17 @@ async def create_database(): exit(1) # 获取数据库列表 database_names = await client.list_database_names() - + db = client[DATABASE_NAME] # 如果数据库不存在,创建数据库 if DATABASE_NAME not in database_names: # 在数据库中创建一个集合,比如名为 "user" - collection = client[DATABASE_NAME]["user"] + collection = db["user"] # 用户数据 await collection.insert_one({"username": "ScopeSentry", 'password': 'b0ce71fcbed8a6ca579d52800145119cc7d999dc8651b62dfc1ced9a984e6e64'}) - collection = client[DATABASE_NAME]["config"] + collection = db["config"] # 系统配置 await collection.insert_one( {"name": "timezone", 'value': 'Asia/Shanghai', 'type': 'system'}) @@ -73,7 +73,7 @@ async def create_database(): # time_now = utc_now.astimezone(SHA_TZ) # formatted_time = time_now.strftime("%Y-%m-%d %H:%M:%S") # subfinder配置 - collection = client[DATABASE_NAME]["config"] + collection = db["config"] # 插入一条数据 await collection.insert_one( {"name": "SubfinderApiConfig", 'value': subfinderApiConfig, 'type': 'subfinder'}) @@ -83,7 +83,7 @@ async def create_database(): # await collection.insert_one( # {"name": "DirDic", 'value': dirDict, 'type': 'dirDict'}) # 目录扫描字典 - fs = AsyncIOMotorGridFSBucket(client) + fs = AsyncIOMotorGridFSBucket(db) content = get_dirDict() if content: byte_content = content.encode('utf-8') @@ -107,47 +107,46 @@ async def create_database(): # await collection.insert_one( # {"name": "DomainDic", 'value': domainDict, 'type': 'domainDict'}) sensitive_data = get_sensitive() - collection = client[DATABASE_NAME]["SensitiveRule"] + collection = db["SensitiveRule"] if sensitive_data: await collection.insert_many(sensitive_data) - collection = client[DATABASE_NAME]["ScheduledTasks"] + collection = db["ScheduledTasks"] await collection.insert_one( {"id": "page_monitoring", "name": "Page Monitoring", 'hour': 24, 'node': [], 'allNode': True, 'type': 'Page Monitoring', 'state': True}) - collection = client[DATABASE_NAME] - await collection.create_collection("notification") + await db.create_collection("notification") - collection = client[DATABASE_NAME]["PortDict"] + collection = db["PortDict"] await collection.insert_many(portDic) - collection = client[DATABASE_NAME]["PocList"] + collection = db["PocList"] pocData = get_poc() await collection.insert_many(pocData) - collection = client[DATABASE_NAME]["project"] + collection = db["project"] project_data, target_data = get_project_data() await collection.insert_many(project_data) - collection = client[DATABASE_NAME]["ProjectTargetData"] + collection = db["ProjectTargetData"] await collection.insert_many(target_data) - collection = client[DATABASE_NAME]["FingerprintRules"] + collection = db["FingerprintRules"] fingerprint = get_finger() await collection.insert_many(fingerprint) else: - collection = client[DATABASE_NAME]["config"] + collection = db["config"] result = await collection.find_one({"name": "timezone"}) set_timezone(result.get('value', 'Asia/Shanghai')) - collection = client[DATABASE_NAME]["ScheduledTasks"] + collection = db["ScheduledTasks"] result = await collection.find_one({"id": "page_monitoring"}) if not result: await collection.insert_one( {"id": "page_monitoring", "name": "Page Monitoring", 'hour': 24, 'type': 'Page Monitoring', 'state': True}) - await get_fingerprint(client[DATABASE_NAME]) - # await get_sens_rule(client[DATABASE_NAME]) - await get_project(client[DATABASE_NAME]) + await get_fingerprint(db) + # await get_sens_rule(db) + await get_project(db) except Exception as e: # 处理异常 logger.error(f"Error creating database: {e}")