236 lines
7.7 KiB
Python
236 lines
7.7 KiB
Python
from futu import *
|
||
from base.MySQLHelper import MySQLHelper # MySQLHelper类保存为单独文件
|
||
from base.LogHelper import LogHelper
|
||
from datetime import datetime
|
||
from typing import Optional, List, Dict
|
||
from ConditionalSelection import FutuStockFilter
|
||
from tqdm import tqdm
|
||
import pandas as pd
|
||
import time
|
||
import csv
|
||
import pandas as pd
|
||
|
||
def get_monthly_avg_data(db_config: dict, table_name: str) -> Optional[List[Dict]]:
    """
    Read the monthly average data from the database.

    Every column of ``table_name`` is selected except the bookkeeping columns
    ``id`` and ``update_time``; rows are ordered by ``stock_code``.

    Args:
        db_config: keyword arguments forwarded to ``MySQLHelper``; must contain
            ``'database'``, which is used to look up the table's columns.
        table_name: unqualified name of the source table.

    Returns:
        List[Dict]: the query result rows, or None when the table has no
        columns, has no exportable columns, has no rows, or any error occurs
        (the reason is logged).
    """
    def quote(identifier: str) -> str:
        # Backtick-quote a MySQL identifier so that column names which are
        # reserved words (e.g. `change`) do not break the generated SQL.
        return "`" + identifier.replace("`", "``") + "`"

    try:
        with MySQLHelper(**db_config) as db:
            # Look up the table's columns in declaration order.
            columns = db.execute_query("""
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                ORDER BY ORDINAL_POSITION
            """, (db_config['database'], table_name))

            if not columns:
                logger.error(f"表 {table_name} 不存在或没有列")
                return None

            # Columns to export (exclude id and update_time).
            field_names = [col['COLUMN_NAME'] for col in columns
                           if col['COLUMN_NAME'] not in ('id', 'update_time')]
            if not field_names:
                # Without this guard the query below would be "SELECT  FROM ...",
                # which is a SQL syntax error.
                logger.error(f"表 {table_name} 没有可导出的列")
                return None

            # Query the data. The table name was validated against
            # INFORMATION_SCHEMA above, so it is an unqualified name and safe
            # to quote as a single identifier.
            data = db.execute_query(f"""
                SELECT {', '.join(quote(name) for name in field_names)}
                FROM {quote(table_name)}
                ORDER BY stock_code
            """)

            if not data:
                logger.error(f"表 {table_name} 中没有数据")
                return None

            return data

    except Exception as e:
        logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
        return None
|
||
|
||
def get_float_share_data(db_config: dict, table_name: str) -> Optional[List[Dict]]:
    """
    Fetch ``stock_code``, ``stock_name`` and ``float_share`` rows, ordered by
    ``stock_code``.

    Args:
        db_config: keyword arguments forwarded to ``MySQLHelper``.
        table_name: table holding the float-share data; interpolated into the
            SQL as-is.

    Returns:
        List[Dict]: the rows returned by the query, or None when the table is
        empty or an error occurs (either case is logged).
    """
    try:
        with MySQLHelper(**db_config) as db:
            query = f"""
                SELECT stock_code, stock_name, float_share
                FROM {table_name}
                ORDER BY stock_code
            """
            rows = db.execute_query(query)

            if rows:
                return rows

            logger.error(f"表 {table_name} 中没有流通股本数据")
            return None
    except Exception as e:
        logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
        return None
|
||
|
||
def export_to_csv(data: List[Dict], output_file: str, mode: str = 'a',
                  write_header: bool = False) -> bool:
    """
    Write rows to a CSV file, appending by default.

    Columns are taken from the keys of the first row, in that order. The file
    is written with the ``utf-8-sig`` encoding (UTF-8 with a BOM).

    Args:
        data: rows to export; every row must have the same keys as the first
            (otherwise ``csv.DictWriter`` raises and False is returned).
        output_file: path of the CSV file.
        mode: ``'w'`` to overwrite, ``'a'`` to append.
        write_header: when True, write a header row (Chinese labels for known
            columns, the raw column name otherwise) if the file did not exist
            yet or ``mode`` is ``'w'``. Defaults to False: no header row.

    Returns:
        bool: True on success; False if ``data`` is empty or writing failed.
    """
    # Imported locally: the module never imports ``os`` explicitly (it is
    # only reachable, if at all, through ``from futu import *``).
    import os

    if not data:
        logger.warning("没有数据可导出")
        return False

    try:
        # Column order follows the keys of the first row.
        field_names = list(data[0].keys())

        # Chinese labels for known columns, used only when write_header is True.
        header_map = {
            'stock_code': '股票代码',
            'stock_name': '股票名称',
            'circular_market_val': '流通市值'
        }

        # Must be checked before open(), which creates the file in 'a'/'w' mode.
        file_exists = os.path.isfile(output_file)

        with open(output_file, mode=mode, newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=field_names)

            # Header only on request, and only for a new or overwritten file,
            # so repeated appends do not interleave header rows with data.
            if write_header and (not file_exists or mode == 'w'):
                writer.writerow({col: header_map.get(col, col) for col in field_names})

            writer.writerows(data)

        logger.info(f"成功{'追加' if mode == 'a' and file_exists else '写入'} {len(data)} 条记录到CSV文件: {output_file}")
        return True

    except Exception as e:
        logger.error(f"导出到CSV失败: {str(e)}")
        return False
|
||
|
||
|
||
def export_data(db_config: dict,
                monthly_table: str,
                csv_file: Optional[str] = None) -> bool:
    """
    Export the monthly average table to a CSV file under ``data/``.

    Args:
        db_config: database configuration forwarded to ``get_monthly_avg_data``.
        monthly_table: name of the monthly average price table.
        csv_file: CSV file name; the data is written (appended, per
            ``export_to_csv``'s default mode) to ``'data/' + csv_file``.
            When omitted or empty, no file is written.

    Returns:
        bool: False if the monthly data could not be read or the CSV export
        failed; True otherwise (including when no ``csv_file`` was given).
    """
    # Read the data from the database.
    monthly_data = get_monthly_avg_data(db_config, monthly_table)
    if not monthly_data:
        logger.error("无法获取月度均价数据")
        return False

    if not csv_file:
        # Nothing requested to export; reading succeeded.
        return True

    # Build the path only once a file name is known ('data/' + None would
    # raise TypeError).
    file_path = 'data/' + csv_file
    return export_to_csv(monthly_data, file_path)
|
||
|
||
def get_stock_codes(db_config: Optional[dict] = None) -> List[str]:
    """
    Return the distinct stock codes from the ``stock_filter`` table.

    Rows with an empty code are skipped, as are rows whose stock name ends in
    ``'R'`` (NOTE(review): presumably the RMB-traded counters — confirm).
    A row with a NULL or empty name is kept rather than aborting the query.

    Args:
        db_config: ``MySQLHelper`` configuration; when None, the module-level
            ``db_config`` (defined in the ``__main__`` block) is used.

    Returns:
        List[str]: the matching stock codes; an empty list on any error
        (the error is logged).
    """
    try:
        if db_config is None:
            db_config = globals()['db_config']
        with MySQLHelper(**db_config) as db:
            results = db.execute_query(
                "SELECT DISTINCT stock_code,stock_name FROM stock_filter")
            # `(name or '')` keeps one NULL/empty name from raising and thereby
            # discarding every code via the except branch below.
            return [row['stock_code'] for row in (results or [])
                    if row['stock_code']
                    and not (row['stock_name'] or '').endswith('R')]
    except Exception as e:
        logger.error(f"获取股票代码失败: {str(e)}")
        return []
|
||
|
||
def get_kdata_codes(tablename: str = "hk_00001",
                    start_date: str = '2025-08-01',
                    end_date: str = '2025-08-31',
                    db_config: Optional[dict] = None) -> List[str]:
    """
    Return the distinct, non-NULL ``close_price`` values of a K-line table
    whose ``trade_date`` lies between ``start_date`` and ``end_date``.

    Despite the function name, the result holds closing prices, not stock codes.

    Args:
        tablename: table to read, e.g. ``"hk_00001"``; interpolated into the
            SQL unquoted, so it must come from trusted code.
        start_date: lower ``trade_date`` bound (SQL BETWEEN, inclusive).
        end_date: upper ``trade_date`` bound (SQL BETWEEN, inclusive).
        db_config: ``MySQLHelper`` configuration; when None, the module-level
            ``db_config`` (defined in the ``__main__`` block) is used.

    Returns:
        The ``close_price`` values as returned by ``execute_query``
        (NOTE(review): annotated ``List[str]`` but presumably numeric — confirm).
        Falsy values are dropped. An empty list is returned on any error.
    """
    try:
        if db_config is None:
            db_config = globals()['db_config']
        with MySQLHelper(**db_config) as db:
            # NOTE(review): DISTINCT collapses days that closed at the same
            # price; if these values feed an average this skews it — confirm.
            sql = """
                SELECT DISTINCT close_price
                FROM {}
                WHERE trade_date BETWEEN %s AND %s
                AND close_price IS NOT NULL
            """.format(tablename)
            result = db.execute_query(sql, (start_date, end_date))

            return [row['close_price'] for row in (result or []) if row['close_price']]
    except Exception as e:
        logger.error(f"获取股票代码失败: {str(e)}")
        return []
|
||
|
||
def get_floatshare_codes(tablename: str = "hk_00001",
                         start_date: str = '2025-08-01',
                         end_date: str = '2025-08-31',
                         db_config: Optional[dict] = None) -> List[str]:
    """
    Return the distinct, non-NULL ``float_share`` values of a K-line table
    whose ``trade_date`` lies between ``start_date`` and ``end_date``.

    Despite the function name, the result holds float-share values, not stock
    codes.

    Args:
        tablename: table to read, e.g. ``"hk_00001"``; interpolated into the
            SQL unquoted, so it must come from trusted code.
        start_date: lower ``trade_date`` bound (SQL BETWEEN, inclusive).
        end_date: upper ``trade_date`` bound (SQL BETWEEN, inclusive).
        db_config: ``MySQLHelper`` configuration; when None, the module-level
            ``db_config`` (defined in the ``__main__`` block) is used.

    Returns:
        The ``float_share`` values as returned by ``execute_query``
        (NOTE(review): annotated ``List[str]`` but presumably numeric — confirm).
        Falsy values are dropped. An empty list is returned on any error.
    """
    try:
        if db_config is None:
            db_config = globals()['db_config']
        with MySQLHelper(**db_config) as db:
            sql = """
                SELECT DISTINCT float_share
                FROM {}
                WHERE trade_date BETWEEN %s AND %s
                AND float_share IS NOT NULL
            """.format(tablename)
            result = db.execute_query(sql, (start_date, end_date))

            return [row['float_share'] for row in (result or []) if row['float_share']]
    except Exception as e:
        logger.error(f"获取股票代码失败: {str(e)}")
        return []
|
||
|
||
from sqlalchemy import create_engine
|
||
if __name__ == "__main__":
    # Imported locally: ``os`` is not imported explicitly at the top of the file.
    import os

    # Application logging. The module-level functions above log through this
    # global ``logger``, so it must be set before any of them is called.
    logger = LogHelper(logger_name = 'export').setup()

    # Database configuration. Each value can be overridden through an
    # environment variable; the literals are the previous hard-coded values,
    # so behaviour is unchanged when the variables are unset.
    # TODO: remove the hard-coded password fallback once MYSQL_PASSWORD is
    # provided everywhere this script runs.
    db_config = {
        'host': os.environ.get('MYSQL_HOST', 'localhost'),
        'user': os.environ.get('MYSQL_USER', 'root'),
        'password': os.environ.get('MYSQL_PASSWORD', 'bzskmysql'),
        'database': os.environ.get('MYSQL_DATABASE', 'hk_kline_1d'),
    }

    market_data_all = get_stock_codes()
    for code in tqdm(market_data_all, desc="读取股票数据", unit="支"):
        # Per-stock table name: 'hk_' + the code minus its first 3 characters
        # (assumes codes look like 'HK.00700' -> 'hk_00700' — confirm).
        custom_table_name = 'hk_' + code[3:]
        datas = get_kdata_codes(custom_table_name)
        float_share = get_floatshare_codes(custom_table_name)
        # NOTE(review): ``datas`` and ``float_share`` are overwritten on every
        # iteration and never used, so these queries currently have no effect.

    # NOTE(review): ``filePath`` is not used by any visible code.
    filePath = 'data/kdata_val.csv'
|
||
|