# import csv # from typing import List, Dict, Optional # from datetime import datetime # from base import LogHelper,MySQLHelper # logger = LogHelper(logger_name = 'export').setup() # def get_monthly_avg_data(db_config: dict, table_name: str) -> Optional[List[Dict]]: # """ # 从数据库读取月度均值数据 # Args: # db_config: 数据库配置 # table_name: 源数据表名 # Returns: # List[Dict]: 查询结果数据集,失败返回None # """ # try: # with MySQLHelper(**db_config) as db: # # 获取表结构信息 # columns = db.execute_query(f""" # SELECT COLUMN_NAME # FROM INFORMATION_SCHEMA.COLUMNS # WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s # ORDER BY ORDINAL_POSITION # """, (db_config['database'], table_name)) # if not columns: # logger.error(f"表 {table_name} 不存在或没有列") # return None # # 获取列名列表(排除id和update_time) # field_names = [col['COLUMN_NAME'] for col in columns # if col['COLUMN_NAME'] not in ('id', 'update_time')] # # 查询数据 # data = db.execute_query(f""" # SELECT {', '.join(field_names)} # FROM {table_name} # ORDER BY stock_code # """) # if not data: # logger.error(f"表 {table_name} 中没有数据") # return None # return data # except Exception as e: # logger.error(f"从数据库读取月度均值数据失败: {str(e)}") # return None # def get_float_share_data(db_config: dict, table_name: str) -> Optional[List[Dict]]: # """ # 从conditionalselection表读取流通股本数据 # Args: # db_config: 数据库配置 # table_name: 源数据表名 # Returns: # List[Dict]: 查询结果数据集,失败返回None # """ # try: # with MySQLHelper(**db_config) as db: # # 查询流通股本数据 # data = db.execute_query(f""" # SELECT stock_code, stock_name, float_share # FROM {table_name} # ORDER BY stock_code # """) # if not data: # logger.error(f"表 {table_name} 中没有流通股本数据") # return None # return data # except Exception as e: # logger.error(f"从数据库读取流通股本数据失败: {str(e)}") # return None # def export_to_csv(data: List[Dict], output_file: str) -> bool: # """ # 将合并后的数据导出到CSV文件 # Args: # data: 要导出的数据集 # output_file: 输出的CSV文件路径 # Returns: # bool: 是否导出成功 # """ # if not data: # return False # try: # # 获取字段名(使用第一个数据的键) # field_names = list(data[0].keys()) # # 字段名到中文的映射 # header_map = { # 'stock_code': '股票代码', # 'stock_name': '股票名称', # 'ym_2410': '2024年10月', # 'ym_2411': '2024年11月', # 'ym_2412': '2024年12月', # 'ym_2501': '2025年01月', # 'ym_2502': '2025年02月', # 'ym_2503': '2025年03月', # 'ym_2504': '2025年04月', # 'ym_2505': '2025年05月', # 'ym_2506': '2025年06月', # 'ym_2507': '2025年07月', # 'ym_2508': '2025年08月', # 'avg_all': '月度均值' # } # with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile: # writer = csv.DictWriter(csvfile, fieldnames=field_names) # # 写入中文表头 # writer.writerow({col: header_map.get(col, col) for col in field_names}) # # 写入数据 # writer.writerows(data) # logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}") # return True # except Exception as e: # logger.error(f"导出到CSV失败: {str(e)}") # return False # def export_data(db_config: dict, # monthly_table: str, # csv_file: str = None) -> bool: # """ # 导出合并后的数据到CSV和/或Excel # Args: # db_config: 数据库配置 # monthly_table: 月度均价表名 # float_share_table: 流通股本表名 # csv_file: CSV输出路径(可选) # excel_file: Excel输出路径(可选) # Returns: # bool: 是否至少有一种格式导出成功 # """ # # 从数据库获取数据 # monthly_data = get_monthly_avg_data(db_config, monthly_table) # if not monthly_data: # logger.error("无法获取月度均价数据") # return False # # 导出结果 # filePath = 'data/' + csv_file # csv_success = True # if csv_file: # csv_success = export_to_csv(monthly_data, filePath) # return csv_success import csv from typing import List, Dict, Optional from base import LogHelper, MySQLHelper class DataExporter: """数据导出器类,用于从数据库提取数据并导出到CSV文件""" # 表头映射配置 HEADER_MAP = { 'stock_code': '股票代码', 'stock_name': '股票名称', 'ym_2410': '2024年10月', 'ym_2411': '2024年11月', 'ym_2412': '2024年12月', 'ym_2501': '2025年01月', 'ym_2502': '2025年02月', 'ym_2503': '2025年03月', 'ym_2504': '2025年04月', 'ym_2505': '2025年05月', 'ym_2506': '2025年06月', 'ym_2507': '2025年07月', 'ym_2508': '2025年08月', 'avg_all': '月度均值' } def __init__(self, db_config: dict, logger_name: str = 'export'): """ 初始化数据导出器 Args: db_config: 数据库配置字典 logger_name: 日志记录器名称 """ self.db_config = db_config self.logger = LogHelper(logger_name=logger_name).setup() def get_monthly_avg_data(self, table_name: str) -> Optional[List[Dict]]: """ 从数据库读取月度均值数据 Args: table_name: 源数据表名 Returns: List[Dict]: 查询结果数据集,失败返回None """ try: with MySQLHelper(**self.db_config) as db: # 获取表结构信息 columns = db.execute_query(f""" SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s ORDER BY ORDINAL_POSITION """, (self.db_config['database'], table_name)) if not columns: self.logger.error(f"表 {table_name} 不存在或没有列") return None # 获取列名列表(排除id和update_time) field_names = [col['COLUMN_NAME'] for col in columns if col['COLUMN_NAME'] not in ('id', 'update_time')] # 查询数据 data = db.execute_query(f""" SELECT {', '.join(field_names)} FROM {table_name} ORDER BY stock_code """) if not data: self.logger.error(f"表 {table_name} 中没有数据") return None return data except Exception as e: self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}") return None def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]: """ 从conditionalselection表读取流通股本数据 Args: table_name: 源数据表名 Returns: List[Dict]: 查询结果数据集,失败返回None """ try: with MySQLHelper(**self.db_config) as db: # 查询流通股本数据 data = db.execute_query(f""" SELECT stock_code, stock_name, float_share FROM {table_name} ORDER BY stock_code """) if not data: self.logger.error(f"表 {table_name} 中没有流通股本数据") return None return data except Exception as e: self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}") return None def export_to_csv(self, data: List[Dict], output_file: str) -> bool: """ 将合并后的数据导出到CSV文件 Args: data: 要导出的数据集 output_file: 输出的CSV文件路径 Returns: bool: 是否导出成功 """ if not data: self.logger.warning("没有数据可导出") return False try: # 获取字段名(使用第一个数据的键) field_names = list(data[0].keys()) with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile: writer = csv.DictWriter(csvfile, fieldnames=field_names) # 写入中文表头 writer.writerow({col: self.HEADER_MAP.get(col, col) for col in field_names}) # 写入数据 writer.writerows(data) self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}") return True except Exception as e: self.logger.error(f"导出到CSV失败: {str(e)}") return False def export_data(self, monthly_table: str, csv_file: str = None) -> bool: """ 导出合并后的数据到CSV Args: monthly_table: 月度均价表名 csv_file: CSV输出路径(可选) Returns: bool: 是否导出成功 """ # 从数据库获取数据 monthly_data = self.get_monthly_avg_data(monthly_table) if not monthly_data: self.logger.error("无法获取月度均价数据") return False # 导出结果 file_path = 'data/' + csv_file if csv_file else None csv_success = True if csv_file: csv_success = self.export_to_csv(monthly_data, file_path) return csv_success