2025-08-22 11:20:41 +08:00
|
|
|
|
import csv
|
|
|
|
|
|
from typing import List, Dict, Optional
|
2025-08-26 10:32:50 +08:00
|
|
|
|
from base import LogHelper, MySQLHelper, Config
|
2025-08-22 11:20:41 +08:00
|
|
|
|
|
|
|
|
|
|
class DataExporter:
    """Data exporter: reads data from the database and exports it to CSV files."""

    def __init__(self, db_config: dict, logger_name: str = 'export'):
        """
        Initialise the data exporter.

        Args:
            db_config: Database connection settings, passed as keyword
                arguments to ``MySQLHelper``. Must contain a ``'database'``
                key (used to look up table columns).
            logger_name: Name of the logger to set up via ``LogHelper``.
        """
        self.db_config = db_config
        self.logger = LogHelper(logger_name=logger_name).setup()
        # Maps raw column names to display names for the CSV header row.
        self.head_map = Config.ConfigInfo.HEADER_MAP

    @staticmethod
    def _quote_identifier(name: str, qualified: bool = False) -> str:
        """Return *name* as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled, as MySQL requires. When *qualified* is
        True, *name* is treated as a dotted ``schema.table`` name and each
        part is quoted separately.
        """
        parts = name.split('.') if qualified else [name]
        return '.'.join('`' + part.replace('`', '``') + '`' for part in parts)

    def get_monthly_avg_data(self, table_name: str,
                             order_by: Optional[str] = 'avg_all') -> Optional[List[Dict]]:
        """
        Read monthly average data from the database.

        Every column except ``id`` and ``update_time`` is selected, in the
        table's column order.

        Args:
            table_name: Source table name in the ``db_config['database']``
                schema.
            order_by: Column to sort by, in descending order. Defaults to
                ``avg_all``. Pass ``None`` to leave the rows unsorted.

        Returns:
            List[Dict]: The result rows, or None when the table does not
            exist, has no exportable columns, lacks the sort column, is
            empty, or the query fails.
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                # Look up the table's columns. An empty result means the table
                # does not exist in the configured schema.
                columns = db.execute_query("""
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                    ORDER BY ORDINAL_POSITION
                """, (self.db_config['database'], table_name))

                if not columns:
                    self.logger.error(f"表 {table_name} 不存在或没有列")
                    return None

                all_columns = [col['COLUMN_NAME'] for col in columns]
                # Exclude the bookkeeping columns id and update_time.
                field_names = [name for name in all_columns
                               if name not in ('id', 'update_time')]
                if not field_names:
                    # Otherwise the SELECT list would be empty, which is invalid SQL.
                    self.logger.error(f"表 {table_name} 除 id/update_time 外没有可导出的列")
                    return None

                order_clause = ''
                if order_by is not None:
                    if order_by not in all_columns:
                        self.logger.error(f"表 {table_name} 中不存在排序字段 {order_by}")
                        return None
                    order_clause = f"ORDER BY {self._quote_identifier(order_by)} DESC"

                # Quote every identifier so that reserved words or unusual
                # characters in table/column names cannot break the statement.
                select_list = ', '.join(self._quote_identifier(name) for name in field_names)
                data = db.execute_query(f"""
                    SELECT {select_list}
                    FROM {self._quote_identifier(table_name)}
                    {order_clause}
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有数据")
                    return None

                return data

        except Exception as e:
            self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
            return None

    def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read float share data (``stock_code``, ``stock_name``, ``float_share``).

        The original documentation names the ``conditionalselection`` table
        as the intended source. The table is actually whichever one the
        caller passes.

        Args:
            table_name: Source table name. A schema-qualified ``db.table``
                name is also accepted.

        Returns:
            List[Dict]: Rows ordered by ``stock_code``, or None if the table
            is empty or the query fails.
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                # The table name comes from the caller and cannot be passed as
                # a bound parameter, so it is quoted as an identifier instead.
                data = db.execute_query(f"""
                    SELECT stock_code, stock_name, float_share
                    FROM {self._quote_identifier(table_name, qualified=True)}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有流通股本数据")
                    return None

                return data

        except Exception as e:
            self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
            return None

    def export_to_csv(self, data: List[Dict], output_file: str) -> bool:
        """
        Write rows to a CSV file (UTF-8 with BOM).

        The header row uses the display names from ``self.head_map``, falling
        back to the raw column name. Missing parent directories are created.

        Args:
            data: Rows to export. The column order is taken from the keys of
                the first row.
            output_file: Path of the CSV file to write.

        Returns:
            bool: True on success. False if *data* is empty or writing fails.
        """
        import os

        if not data:
            self.logger.warning("没有数据可导出")
            return False

        try:
            # Use the keys of the first row as the column list.
            field_names = list(data[0].keys())

            # Make sure the target directory exists (e.g. data/ on a fresh checkout).
            parent_dir = os.path.dirname(output_file)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # utf-8-sig adds a BOM, presumably so spreadsheet tools such as
            # Excel detect UTF-8 correctly.
            with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=field_names)
                # Write the display-name header instead of writer.writeheader().
                writer.writerow({col: self.head_map.get(col, col) for col in field_names})
                writer.writerows(data)

            # Report success only after the file has been flushed and closed.
            self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
            return True

        except Exception as e:
            self.logger.error(f"导出到CSV失败: {str(e)}")
            return False

    def export_data(self, monthly_table: str, csv_file: Optional[str] = None) -> bool:
        """
        Read monthly average data and export it to a CSV file under ``data/``.

        Args:
            monthly_table: Name of the monthly average price table.
            csv_file: Output file name, relative to the ``data`` directory.
                When omitted or empty, the data is only read and nothing is
                written.

        Returns:
            bool: False if the data could not be read or the CSV write
            failed. True otherwise, including when no *csv_file* is given.
        """
        import os

        monthly_data = self.get_monthly_avg_data(monthly_table)
        if not monthly_data:
            self.logger.error("无法获取月度均价数据")
            return False

        if not csv_file:
            # Nothing to write; reading the data successfully counts as success.
            return True

        return self.export_to_csv(monthly_data, os.path.join('data', csv_file))