Async Request Utility Class

This wraps up an async request utility class: you can plug in your own request construction and response parsing, and control both the concurrency level and the batch size.

There is still room for improvement, for example retrying failed requests, saving failed tasks, and restoring failed tasks; a rough sketch of such an extension follows the full example at the end.

import aiohttp
import asyncio
import json
import os
import sys
import time
from typing import Any, Callable, Dict, List


class AdvancedAsyncScraper:
    def __init__(
            self,
            request_builder: Callable[[Dict], Dict[str, Any]],
            response_handler: Callable[[Dict, Dict], Any],
            task_list: List[Dict],
            config: Dict[str, Any]
    ):
        """
        增强版异步爬虫框架(支持复杂任务字典)

        :param request_builder: 请求构造器 (task -> 请求配置字典)
        :param response_handler: 响应处理器 (响应数据, task -> 处理结果)
        :param task_list: 任务字典列表
        :param config: 全局配置
        """
        self.request_builder = request_builder
        self.response_handler = response_handler
        self.tasks = task_list
        self.config = config

        # Create the output directory for results
        os.makedirs(config.get("output_dir", "results"), exist_ok=True)

    async def process_task(
            self,
            session: aiohttp.ClientSession,
            semaphore: asyncio.Semaphore,
            task: Dict
    ):
        """处理单个任务的完整流程"""
        task_id = task.get('task_id', str(task)[:50])
        start_time = time.time()

        try:
            async with semaphore:
                self._log(f"开始处理任务: {task_id}")

                # Build the request config
                request_config = self.request_builder(task)

                # Execute the request
                response_data = await self._execute_request(session, task, request_config)
                if not response_data:
                    return

                # Handle the response
                self.response_handler(response_data, task)

                # Record elapsed time
                cost = time.time() - start_time
                self._log(f"Task done: {task_id} | elapsed: {cost:.2f}s")

        except Exception as e:
            self._log(f"任务失败 [{task_id}]: {str(e)}")

    async def run(self):
        """启动任务处理器"""
        self._log(f"启动处理器,总任务数: {len(self.tasks)}")

        async with aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(ssl=False),
                headers=self.config.get("default_headers")
        ) as session:

            semaphore = asyncio.Semaphore(self.config["concurrency"])
            batch_size = self.config.get("batch_size", 100)

            # Process tasks in batches
            for i in range(0, len(self.tasks), batch_size):
                batch = self.tasks[i:i + batch_size]
                self._log(f"Processing batch: tasks {i + 1}-{i + len(batch)}")

                await asyncio.gather(*[
                    self.process_task(session, semaphore, task)
                    for task in batch
                ])

                if self.config.get("batch_delay"):
                    await asyncio.sleep(self.config["batch_delay"])

        self._log("所有任务处理完成")

    async def _execute_request(self, session, task, config):
        """执行实际请求"""
        try:
            async with session.request(
                    method=config["method"],
                    url=config["url"],
                    headers=config.get("headers"),
                    params=config.get("params"),
                    json=config.get("json"),
                    data=config.get("data"),
                    timeout=aiohttp.ClientTimeout(total=self.config["request_timeout"])
            ) as resp:
                return await self._parse_response(resp, task)
        except Exception as e:
            self._log(f"请求异常 [{task.get('task_id')}]: {str(e)}")
            return None

    async def _parse_response(self, resp, task):
        """解析响应数据"""
        try:
            content = await resp.text()
            if resp.content_type == 'application/json':
                return json.loads(content)
            return {"status": resp.status, "raw_data": content}
        except Exception as e:
            self._log(f"响应解析失败 [{task.get('task_id')}]: {str(e)}")
            return None

    def _log(self, message):
        """统一日志格式"""
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")


# ----------------------------
# Usage example
# ----------------------------
def complex_request_builder(task: Dict) -> Dict:
    """请求构造示例"""
    # 根据不同任务类型构建请求
    base_url = "https://www.baidu.com/"

    # if task["type"] == "user_detail":
    #     return {
    #         "method": "GET",
    #         "url": f"{base_url}/users/{task['user_id']}",
    #         "params": {"fields": "all"}
    #     }
    # elif task["type"] == "order_list":
    #     return {
    #         "method": "POST",
    #         "url": f"{base_url}/orders/search",
    #         "json": {
    #             "status": task["status"],
    #             "page": task.get("page", 1)
    #         }
    #     }
    # else:
    #     raise ValueError("未知任务类型")

    return {
        "method": "GET",
        "url": f"{base_url}?i={task['id']}"
    }

def response_processor(data: Dict, task: Dict) -> None:
    """Example response handler"""
    print(f"Task {task['id']} done: status {data['status']}")


if __name__ == "__main__":
    # Build the task list
    task_pool = []

    for i in range(10):
        task_pool.append({
            "task_id": i,
            "id": i
        })

    # Global configuration
    config = {
        "concurrency": 3,           # max concurrent requests
        "request_timeout": 20,      # HTTP timeout (seconds)
        "batch_size": 10,           # tasks per batch
        "batch_delay": 1,           # delay between batches (seconds)
        "default_headers": {        # default request headers
            "Authorization": "Bearer YOUR_TOKEN",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
        }
    }

    # Initialize and run
    scraper = AdvancedAsyncScraper(
        request_builder=complex_request_builder,    # request builder
        response_handler=response_processor,        # response handler
        task_list=task_pool,    # task list
        config=config           # global config
    )

    try:
        # On Windows, use the selector event loop policy to avoid aiohttp errors on shutdown
        if sys.platform == "win32":
            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        asyncio.run(scraper.run())
    except KeyboardInterrupt:
        print("\n[System] Interrupted by user")