Files
HKDataManagment/DataAnalysis/DataExporter.py
2025-08-22 11:20:41 +08:00

336 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# import csv
# from typing import List, Dict, Optional
# from datetime import datetime
# from base import LogHelper,MySQLHelper
# logger = LogHelper(logger_name = 'export').setup()
# def get_monthly_avg_data(db_config: dict, table_name: str) -> Optional[List[Dict]]:
# """
# 从数据库读取月度均值数据
# Args:
# db_config: 数据库配置
# table_name: 源数据表名
# Returns:
# List[Dict]: 查询结果数据集失败返回None
# """
# try:
# with MySQLHelper(**db_config) as db:
# # 获取表结构信息
# columns = db.execute_query(f"""
# SELECT COLUMN_NAME
# FROM INFORMATION_SCHEMA.COLUMNS
# WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
# ORDER BY ORDINAL_POSITION
# """, (db_config['database'], table_name))
# if not columns:
# logger.error(f"表 {table_name} 不存在或没有列")
# return None
# # 获取列名列表(排除id和update_time)
# field_names = [col['COLUMN_NAME'] for col in columns
# if col['COLUMN_NAME'] not in ('id', 'update_time')]
# # 查询数据
# data = db.execute_query(f"""
# SELECT {', '.join(field_names)}
# FROM {table_name}
# ORDER BY stock_code
# """)
# if not data:
# logger.error(f"表 {table_name} 中没有数据")
# return None
# return data
# except Exception as e:
# logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
# return None
# def get_float_share_data(db_config: dict, table_name: str) -> Optional[List[Dict]]:
# """
# 从conditionalselection表读取流通股本数据
# Args:
# db_config: 数据库配置
# table_name: 源数据表名
# Returns:
# List[Dict]: 查询结果数据集失败返回None
# """
# try:
# with MySQLHelper(**db_config) as db:
# # 查询流通股本数据
# data = db.execute_query(f"""
# SELECT stock_code, stock_name, float_share
# FROM {table_name}
# ORDER BY stock_code
# """)
# if not data:
# logger.error(f"表 {table_name} 中没有流通股本数据")
# return None
# return data
# except Exception as e:
# logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
# return None
# def export_to_csv(data: List[Dict], output_file: str) -> bool:
# """
# 将合并后的数据导出到CSV文件
# Args:
# data: 要导出的数据集
# output_file: 输出的CSV文件路径
# Returns:
# bool: 是否导出成功
# """
# if not data:
# return False
# try:
# # 获取字段名(使用第一个数据的键)
# field_names = list(data[0].keys())
# # 字段名到中文的映射
# header_map = {
# 'stock_code': '股票代码',
# 'stock_name': '股票名称',
# 'ym_2410': '2024年10月',
# 'ym_2411': '2024年11月',
# 'ym_2412': '2024年12月',
# 'ym_2501': '2025年01月',
# 'ym_2502': '2025年02月',
# 'ym_2503': '2025年03月',
# 'ym_2504': '2025年04月',
# 'ym_2505': '2025年05月',
# 'ym_2506': '2025年06月',
# 'ym_2507': '2025年07月',
# 'ym_2508': '2025年08月',
# 'avg_all': '月度均值'
# }
# with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
# writer = csv.DictWriter(csvfile, fieldnames=field_names)
# # 写入中文表头
# writer.writerow({col: header_map.get(col, col) for col in field_names})
# # 写入数据
# writer.writerows(data)
# logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
# return True
# except Exception as e:
# logger.error(f"导出到CSV失败: {str(e)}")
# return False
# def export_data(db_config: dict,
# monthly_table: str,
# csv_file: str = None) -> bool:
# """
# 导出合并后的数据到CSV和/或Excel
# Args:
# db_config: 数据库配置
# monthly_table: 月度均价表名
# float_share_table: 流通股本表名
# csv_file: CSV输出路径(可选)
# excel_file: Excel输出路径(可选)
# Returns:
# bool: 是否至少有一种格式导出成功
# """
# # 从数据库获取数据
# monthly_data = get_monthly_avg_data(db_config, monthly_table)
# if not monthly_data:
# logger.error("无法获取月度均价数据")
# return False
# # 导出结果
# filePath = 'data/' + csv_file
# csv_success = True
# if csv_file:
# csv_success = export_to_csv(monthly_data, filePath)
# return csv_success
import csv
from typing import List, Dict, Optional
from base import LogHelper, MySQLHelper
class DataExporter:
"""数据导出器类用于从数据库提取数据并导出到CSV文件"""
# 表头映射配置
HEADER_MAP = {
'stock_code': '股票代码',
'stock_name': '股票名称',
'ym_2410': '2024年10月',
'ym_2411': '2024年11月',
'ym_2412': '2024年12月',
'ym_2501': '2025年01月',
'ym_2502': '2025年02月',
'ym_2503': '2025年03月',
'ym_2504': '2025年04月',
'ym_2505': '2025年05月',
'ym_2506': '2025年06月',
'ym_2507': '2025年07月',
'ym_2508': '2025年08月',
'avg_all': '月度均值'
}
def __init__(self, db_config: dict, logger_name: str = 'export'):
"""
初始化数据导出器
Args:
db_config: 数据库配置字典
logger_name: 日志记录器名称
"""
self.db_config = db_config
self.logger = LogHelper(logger_name=logger_name).setup()
def get_monthly_avg_data(self, table_name: str) -> Optional[List[Dict]]:
"""
从数据库读取月度均值数据
Args:
table_name: 源数据表名
Returns:
List[Dict]: 查询结果数据集失败返回None
"""
try:
with MySQLHelper(**self.db_config) as db:
# 获取表结构信息
columns = db.execute_query(f"""
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION
""", (self.db_config['database'], table_name))
if not columns:
self.logger.error(f"{table_name} 不存在或没有列")
return None
# 获取列名列表(排除id和update_time)
field_names = [col['COLUMN_NAME'] for col in columns
if col['COLUMN_NAME'] not in ('id', 'update_time')]
# 查询数据
data = db.execute_query(f"""
SELECT {', '.join(field_names)}
FROM {table_name}
ORDER BY stock_code
""")
if not data:
self.logger.error(f"{table_name} 中没有数据")
return None
return data
except Exception as e:
self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
return None
def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]:
"""
从conditionalselection表读取流通股本数据
Args:
table_name: 源数据表名
Returns:
List[Dict]: 查询结果数据集失败返回None
"""
try:
with MySQLHelper(**self.db_config) as db:
# 查询流通股本数据
data = db.execute_query(f"""
SELECT stock_code, stock_name, float_share
FROM {table_name}
ORDER BY stock_code
""")
if not data:
self.logger.error(f"{table_name} 中没有流通股本数据")
return None
return data
except Exception as e:
self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
return None
def export_to_csv(self, data: List[Dict], output_file: str) -> bool:
"""
将合并后的数据导出到CSV文件
Args:
data: 要导出的数据集
output_file: 输出的CSV文件路径
Returns:
bool: 是否导出成功
"""
if not data:
self.logger.warning("没有数据可导出")
return False
try:
# 获取字段名(使用第一个数据的键)
field_names = list(data[0].keys())
with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=field_names)
# 写入中文表头
writer.writerow({col: self.HEADER_MAP.get(col, col) for col in field_names})
# 写入数据
writer.writerows(data)
self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
return True
except Exception as e:
self.logger.error(f"导出到CSV失败: {str(e)}")
return False
def export_data(self, monthly_table: str, csv_file: str = None) -> bool:
"""
导出合并后的数据到CSV
Args:
monthly_table: 月度均价表名
csv_file: CSV输出路径(可选)
Returns:
bool: 是否导出成功
"""
# 从数据库获取数据
monthly_data = self.get_monthly_avg_data(monthly_table)
if not monthly_data:
self.logger.error("无法获取月度均价数据")
return False
# 导出结果
file_path = 'data/' + csv_file if csv_file else None
csv_success = True
if csv_file:
csv_success = self.export_to_csv(monthly_data, file_path)
return csv_success