An Async Request Utility Class
This post wraps an asynchronous request utility class around aiohttp. You can plug in your own request construction and response parsing, and control both the concurrency level and the batch size.
There is still room to grow, for example: retrying failed requests, saving failed tasks, and resuming failed tasks later (a sketch of these upgrades follows the usage example at the bottom).
import aiohttp
import asyncio
import json
import os
import sys
import time
from typing import Any, Callable, Dict, List
class AdvancedAsyncScraper:
    def __init__(
        self,
        request_builder: Callable[[Dict], Dict[str, Any]],
        response_handler: Callable[[Dict, Dict], Any],
        task_list: List[Dict],
        config: Dict[str, Any]
    ):
        """
        Enhanced async scraper framework (supports complex task dicts).
        :param request_builder: request builder (task -> request config dict)
        :param response_handler: response handler (response data, task -> result)
        :param task_list: list of task dicts
        :param config: global configuration
        """
        self.request_builder = request_builder
        self.response_handler = response_handler
        self.tasks = task_list
        self.config = config
        # Make sure the output directory exists
        os.makedirs(config.get("output_dir", "results"), exist_ok=True)
    async def process_task(
        self,
        session: aiohttp.ClientSession,
        semaphore: asyncio.Semaphore,
        task: Dict
    ):
        """Run the full pipeline for a single task."""
        task_id = task.get('task_id', str(task)[:50])
        start_time = time.time()
        try:
            # The semaphore caps the number of in-flight requests
            async with semaphore:
                self._log(f"Starting task: {task_id}")
                # Build the request parameters
                request_config = self.request_builder(task)
                # Execute the request
                response_data = await self._execute_request(session, task, request_config)
                if not response_data:
                    return
                # Process the response
                self.response_handler(response_data, task)
                # Record the elapsed time
                cost = time.time() - start_time
                self._log(f"Task finished: {task_id} | elapsed: {cost:.2f}s")
        except Exception as e:
            self._log(f"Task failed [{task_id}]: {str(e)}")
    async def run(self):
        """Start the task processor."""
        self._log(f"Processor started, total tasks: {len(self.tasks)}")
        async with aiohttp.ClientSession(
            # ssl=False disables certificate verification; remove it if you need strict TLS
            connector=aiohttp.TCPConnector(ssl=False),
            headers=self.config.get("default_headers")
        ) as session:
            semaphore = asyncio.Semaphore(self.config["concurrency"])
            batch_size = self.config.get("batch_size", 100)
            # Process the tasks batch by batch
            for i in range(0, len(self.tasks), batch_size):
                batch = self.tasks[i:i + batch_size]
                self._log(f"Processing batch: {i + 1}-{i + len(batch)}")
                await asyncio.gather(*[
                    self.process_task(session, semaphore, task)
                    for task in batch
                ])
                if self.config.get("batch_delay"):
                    await asyncio.sleep(self.config["batch_delay"])
        self._log("All tasks processed")
    async def _execute_request(self, session, task, config):
        """Perform the actual HTTP request."""
        try:
            async with session.request(
                method=config["method"],
                url=config["url"],
                headers=config.get("headers"),
                params=config.get("params"),
                json=config.get("json"),
                data=config.get("data"),
                timeout=aiohttp.ClientTimeout(total=self.config["request_timeout"])
            ) as resp:
                return await self._parse_response(resp, task)
        except Exception as e:
            self._log(f"Request error [{task.get('task_id')}]: {str(e)}")
            return None
    async def _parse_response(self, resp, task):
        """Parse the response payload."""
        try:
            content = await resp.text()
            if resp.content_type == 'application/json':
                return json.loads(content)
            return {"status": resp.status, "raw_data": content}
        except Exception as e:
            self._log(f"Failed to parse response [{task.get('task_id')}]: {str(e)}")
            return None
    def _log(self, message):
        """Uniform log format."""
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")
# ----------------------------
# Usage example
# ----------------------------
def complex_request_builder(task: Dict) -> Dict:
    """Example request builder."""
    # Build the request according to the task type, e.g.:
    base_url = "https://www.baidu.com/"
    # if task["type"] == "user_detail":
    #     return {
    #         "method": "GET",
    #         "url": f"{base_url}users/{task['user_id']}",
    #         "params": {"fields": "all"}
    #     }
    # elif task["type"] == "order_list":
    #     return {
    #         "method": "POST",
    #         "url": f"{base_url}orders/search",
    #         "json": {
    #             "status": task["status"],
    #             "page": task.get("page", 1)
    #         }
    #     }
    # else:
    #     raise ValueError("Unknown task type")
    return {
        "method": "GET",
        "url": f"{base_url}?i={task['id']}"
    }
def response_processor(data: Dict, task: Dict) -> None:
    """Example response handler."""
    # A real handler would typically write results into the output directory the scraper creates
    print(f"Task {task['id']} done, status: {data['status']}")
if __name__ == "__main__":
    # Create the tasks
    task_pool = []
    for i in range(10):
        task_pool.append({
            "task_id": i,
            "id": i
        })
    # Global configuration
    config = {
        "concurrency": 3,           # concurrency level
        "request_timeout": 20,      # HTTP timeout in seconds
        "batch_size": 10,           # batch size
        "batch_delay": 1,           # delay between batches in seconds
        "default_headers": {        # default request headers
            "Authorization": "Bearer YOUR_TOKEN",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
        }
    }
    # Initialize and run
    scraper = AdvancedAsyncScraper(
        request_builder=complex_request_builder,  # request builder
        response_handler=response_processor,      # response handler
        task_list=task_pool,                      # task list
        config=config                             # configuration
    )
    try:
        # Required on Windows, otherwise aiohttp errors on loop shutdown;
        # the policy only exists on Windows, so guard the call
        if sys.platform == "win32":
            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        asyncio.run(scraper.run())
    except KeyboardInterrupt:
        print("\n[System] Interrupted by user")
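As promised at the top, here is a minimal sketch of the retry / save-failed-tasks / resume upgrades, layered on as a subclass. The RetryingScraper name, the max_retries and retry_backoff config keys, and the failed_tasks.jsonl path are all illustrative assumptions, not part of the original class.
import random  # only needed for the jittered backoff below

class RetryingScraper(AdvancedAsyncScraper):
    """Sketch: retry failed requests, then persist unrecoverable tasks."""

    async def _execute_request(self, session, task, config):
        # Hypothetical config keys: "max_retries" and "retry_backoff"
        max_retries = self.config.get("max_retries", 3)
        backoff = self.config.get("retry_backoff", 1.0)
        for attempt in range(1, max_retries + 1):
            # The base implementation returns None on any request error
            result = await super()._execute_request(session, task, config)
            if result is not None:
                return result
            if attempt < max_retries:
                # Exponential backoff with jitter before the next attempt
                await asyncio.sleep(backoff * 2 ** (attempt - 1) * (1 + random.random()))
        # Out of retries: persist the task so a later run can pick it up
        self._save_failed_task(task)
        return None

    def _save_failed_task(self, task):
        # Append the failed task as one JSON line for later resumption
        path = os.path.join(self.config.get("output_dir", "results"), "failed_tasks.jsonl")
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(task, ensure_ascii=False) + "\n")

def load_failed_tasks(path="results/failed_tasks.jsonl") -> List[Dict]:
    """Reload previously failed tasks so they can be fed back into run()."""
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
Resuming is then just constructing a new scraper with task_list=load_failed_tasks() and the same builder, handler, and config.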