import csv import pandas as pd from MySQLHelper import MySQLHelper import logging from typing import List, Dict, Optional, Tuple from datetime import datetime def get_monthly_avg_data(db_config: dict, table_name: str) -> Optional[List[Dict]]: """ 从数据库读取月度均值数据 Args: db_config: 数据库配置 table_name: 源数据表名 Returns: List[Dict]: 查询结果数据集,失败返回None """ try: with MySQLHelper(**db_config) as db: # 获取表结构信息 columns = db.execute_query(f""" SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s ORDER BY ORDINAL_POSITION """, (db_config['database'], table_name)) if not columns: logging.error(f"表 {table_name} 不存在或没有列") return None # 获取列名列表(排除id和update_time) field_names = [col['COLUMN_NAME'] for col in columns if col['COLUMN_NAME'] not in ('id', 'update_time')] # 查询数据 data = db.execute_query(f""" SELECT {', '.join(field_names)} FROM {table_name} ORDER BY stock_code """) if not data: logging.error(f"表 {table_name} 中没有数据") return None return data except Exception as e: logging.error(f"从数据库读取月度均值数据失败: {str(e)}") return None def get_float_share_data(db_config: dict, table_name: str) -> Optional[List[Dict]]: """ 从conditionalselection表读取流通股本数据 Args: db_config: 数据库配置 table_name: 源数据表名 Returns: List[Dict]: 查询结果数据集,失败返回None """ try: with MySQLHelper(**db_config) as db: # 查询流通股本数据 data = db.execute_query(f""" SELECT stock_code, stock_name, float_share FROM {table_name} ORDER BY stock_code """) if not data: logging.error(f"表 {table_name} 中没有流通股本数据") return None return data except Exception as e: logging.error(f"从数据库读取流通股本数据失败: {str(e)}") return None def merge_data(monthly_data: List[Dict], float_share_data: List[Dict]) -> List[Dict]: """ 合并月度均价数据和流通股本数据 Args: monthly_data: 月度均价数据 float_share_data: 流通股本数据 Returns: List[Dict]: 合并后的数据集 """ merged_data = [] float_share_dict = {item['stock_code']: item['float_share'] for item in float_share_data} for item in monthly_data: merged_item = item.copy() merged_item['float_share'] = float_share_dict.get(item['stock_code'], 'N/A') merged_data.append(merged_item) return merged_data def export_to_csv(data: List[Dict], output_file: str) -> bool: """ 将合并后的数据导出到CSV文件 Args: data: 要导出的数据集 output_file: 输出的CSV文件路径 Returns: bool: 是否导出成功 """ if not data: return False try: # 获取字段名(使用第一个数据的键) field_names = list(data[0].keys()) # 字段名到中文的映射 header_map = { 'stock_code': '股票代码', 'stock_name': '股票名称', 'float_share': '流通股本(千股)', 'ym_2410': '2024年10月均收盘价', 'ym_2411': '2024年11月均收盘价', 'ym_2412': '2024年12月均收盘价', 'ym_2501': '2025年1月均收盘价', 'ym_2502': '2025年2月均收盘价', 'ym_2503': '2025年3月均收盘价', 'ym_2504': '2025年4月均收盘价', 'ym_2505': '2025年5月均收盘价', 'ym_2506': '2025年6月均收盘价', 'ym_2507': '2025年7月均收盘价', 'ym_2508': '2025年8月均收盘价' } with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile: writer = csv.DictWriter(csvfile, fieldnames=field_names) # 写入中文表头 writer.writerow({col: header_map.get(col, col) for col in field_names}) # 写入数据 writer.writerows(data) logging.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}") return True except Exception as e: logging.error(f"导出到CSV失败: {str(e)}") return False def export_to_excel(data: List[Dict], output_file: str) -> bool: """ 将合并后的数据导出为Excel文件(包含多个工作表) Args: data: 要导出的数据集 output_file: 输出的Excel文件路径 Returns: bool: 是否导出成功 """ if not data: return False try: # 转换为DataFrame df = pd.DataFrame(data) # 设置股票代码为索引 if 'stock_code' in df.columns: df.set_index('stock_code', inplace=True) # 创建Excel writer对象 with pd.ExcelWriter(output_file, engine='openpyxl') as writer: # 1. 原始数据工作表 df.to_excel(writer, sheet_name='合并数据') # 2. 统计信息工作表(仅当有数值列时) numeric_cols = [col for col in df.columns if col.startswith('ym_') and pd.api.types.is_numeric_dtype(df[col])] if numeric_cols: try: stats = df[numeric_cols].describe().loc[['mean', 'min', 'max', 'std']] stats.to_excel(writer, sheet_name='统计信息') except KeyError: logging.warning("无法生成完整的统计信息,数据可能不足") # 生成简化版统计信息 stats = df[numeric_cols].agg(['mean', 'min', 'max', 'std']) stats.to_excel(writer, sheet_name='统计信息') # 3. 涨幅排名工作表(需要至少两个月份数据) if len(numeric_cols) >= 2: first_month = numeric_cols[0] last_month = numeric_cols[-1] try: df['涨幅(%)'] = (df[last_month] - df[first_month]) / df[first_month] * 100 result_df = df[['stock_name', '涨幅(%)', 'float_share']].copy() result_df.dropna(subset=['涨幅(%)'], inplace=True) result_df.sort_values('涨幅(%)', ascending=False, inplace=True) result_df.to_excel(writer, sheet_name='涨幅排名') except Exception as e: logging.warning(f"无法计算涨幅: {str(e)}") # 4. 月度趋势工作表 if numeric_cols: try: trend_df = df[numeric_cols].transpose() trend_df.index = [col.replace('ym_', '') for col in numeric_cols] trend_df.to_excel(writer, sheet_name='月度趋势') except Exception as e: logging.warning(f"无法生成月度趋势: {str(e)}") # 5. 流通股本分析工作表 if 'float_share' in df.columns and pd.api.types.is_numeric_dtype(df['float_share']): try: float_stats = df['float_share'].describe().to_frame().T float_stats.to_excel(writer, sheet_name='流通股本分析') except Exception as e: logging.warning(f"无法生成流通股本分析: {str(e)}") logging.info(f"成功导出Excel文件: {output_file}") return True except Exception as e: logging.error(f"导出Excel失败: {str(e)}") return False def export_combined_data(db_config: dict, monthly_table: str, float_share_table: str, csv_file: str = None, excel_file: str = None) -> bool: """ 导出合并后的数据到CSV和/或Excel Args: db_config: 数据库配置 monthly_table: 月度均价表名 float_share_table: 流通股本表名 csv_file: CSV输出路径(可选) excel_file: Excel输出路径(可选) Returns: bool: 是否至少有一种格式导出成功 """ if not csv_file and not excel_file: logging.error("必须指定至少一种输出格式") return False # 从数据库获取数据 monthly_data = get_monthly_avg_data(db_config, monthly_table) if not monthly_data: logging.error("无法获取月度均价数据") return False float_share_data = get_float_share_data(db_config, float_share_table) if not float_share_data: logging.error("无法获取流通股本数据") return False # 合并数据 merged_data = merge_data(monthly_data, float_share_data) # 导出结果 csv_success = True excel_success = True if csv_file: csv_success = export_to_csv(merged_data, csv_file) if excel_file: excel_success = export_to_excel(merged_data, excel_file) return csv_success or excel_success # 主程序入口 if __name__ == "__main__": # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('export_combined_data.log', encoding='utf-8'), logging.StreamHandler() ] ) # 数据库配置 db_config = { 'host': 'localhost', 'user': 'root', 'password': 'bzskmysql', 'database': 'klinedata_1d_hk_akshare' } # 导出合并数据 success = export_combined_data( db_config=db_config, monthly_table="hk_monthly_avg_2410_2508", float_share_table="conditionalselection", csv_file="hk_stocks_combined_data.csv", excel_file="hk_stocks_combined_data.xlsx" ) if success: logging.info("数据合并导出成功完成") else: logging.error("数据合并导出过程中出现错误")