450 lines
17 KiB
Python
450 lines
17 KiB
Python
# """
|
||
# 工具包函数
|
||
|
||
# —— 计算月均流通市值:
|
||
# 一个月内,每天流通市值的均值(当日流通股数量 * 当日股价)
|
||
# —— 根据需要补充 ...
|
||
|
||
# """
|
||
|
||
import pandas as pd
|
||
from datetime import datetime
|
||
from futu import *
|
||
from tqdm import tqdm
|
||
from pathlib import Path
|
||
from typing import Optional, List, Dict, Union, Tuple
|
||
import csv
|
||
from typing import List, Dict, Optional
|
||
from datetime import datetime
|
||
from base import LogHelper, MySQLHelper, ConfigInfo
|
||
|
||
|
||
class MarketDataCalculator:
    """
    Market data calculation utilities.

    Features:
    - Compute the monthly average free-float market cap: the mean, over one
      month, of the daily value (float shares that day * close price that day)
    - Data validation and conversion
    - Fetch market data from the Futu API
    - Fetch stock codes from the database
    - Create and manage database table structures
    - Export data to CSV files
    """

    def __init__(self, db_config: dict, logger_name: str = 'Calculate'):
        """
        Initialise the calculator.

        Args:
            db_config: Keyword arguments forwarded to ``MySQLHelper``; the
                ``'database'`` key is also read by ``get_monthly_avg_data``.
            logger_name: Name handed to ``LogHelper``.
        """
        self.db_config = db_config
        self.logger = LogHelper(logger_name=logger_name).setup()
        # Mapping of month column name -> (start_date, end_date), as consumed
        # by calculate_and_save_monthly_avg.
        self.month_ranges = ConfigInfo.MONTH_RANGES
        # Mapping of column name -> CSV header label, used by export_to_csv.
        self.head_map = ConfigInfo.HEADER_MAP

    def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool:
        """
        Create the monthly-average table if it does not exist yet.

        The table holds one row per stock with the columns ``ym_2501`` ..
        ``ym_2509`` plus ``avg_all``.

        Args:
            target_table: Name of the table to create. It is interpolated
                directly into the DDL, so it must be a trusted identifier.

        Returns:
            bool: True when the statement ran, False when any exception was
            raised (the error is logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                # NOTE(review): the table COMMENT below still mentions
                # 2024.10-2025.08 although the columns cover 2025-01..2025-09;
                # it is left untouched because it is part of the emitted DDL.
                create_sql = f"""
                    CREATE TABLE IF NOT EXISTS {target_table} (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        stock_code VARCHAR(20) NOT NULL COMMENT '股票代码',
                        stock_name VARCHAR(50) COMMENT '股票名称',
                        ym_2501 DECIMAL(20, 5) COMMENT '2025年01月',
                        ym_2502 DECIMAL(20, 5) COMMENT '2025年02月',
                        ym_2503 DECIMAL(20, 5) COMMENT '2025年03月',
                        ym_2504 DECIMAL(20, 5) COMMENT '2025年04月',
                        ym_2505 DECIMAL(20, 5) COMMENT '2025年05月',
                        ym_2506 DECIMAL(20, 5) COMMENT '2025年06月',
                        ym_2507 DECIMAL(20, 5) COMMENT '2025年07月',
                        ym_2508 DECIMAL(20, 5) COMMENT '2025年08月',
                        ym_2509 DECIMAL(20, 5) COMMENT '2025年09月',
                        avg_all DECIMAL(20, 5) COMMENT '月间均值',
                        update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
                        UNIQUE KEY uk_stock_code (stock_code)
                    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='月均流通市值表(2024.10-2025.08)'
                """
                db.execute_update(create_sql)
                return True
        except Exception as e:
            self.logger.error(f"创建表失败: {str(e)}")
            return False

    def calculate_and_save_monthly_avg(self,
                                       stock_code: str = "code",
                                       target_table: str = "monthly_close_avg") -> bool:
        """
        Compute the monthly average free-float market cap of one stock and
        upsert it into ``target_table``.

        For every entry of ``self.month_ranges`` the average of
        ``close_price * float_share`` over the date range is read from the
        per-stock table ``'hk_' + stock_code[3:]``. The mean of all months
        that produced a value is stored in ``avg_all``.

        Args:
            stock_code: Code of the stock to process; it must exist in the
                ``stock_filter`` table.
            target_table: Table receiving the result row.

        Returns:
            bool: True when the row was written. False when the table could
            not be created, the stock is unknown, or any exception occurred.
        """
        try:
            # Make sure the destination table exists.
            if not self.create_monthly_avg_table(target_table):
                return False

            with MySQLHelper(**self.db_config) as db:
                # stock_filter is treated as the reference list of stocks.
                sql = """
                    SELECT stock_code, stock_name
                    FROM stock_filter
                    WHERE stock_code = %s
                """
                # The parameters must be a real 1-tuple: "(stock_code)" is
                # just a parenthesised string.
                stock_info = db.execute_query(sql, (stock_code,))
                if not stock_info:
                    self.logger.warning(f"股票 {stock_code} 不在 stock_filter 表中,跳过")
                    return False

                stock_code = stock_info[0]['stock_code']
                stock_name = stock_info[0]['stock_name']

                monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}

                # Per-stock source table: everything after the first three
                # characters of the code (presumably a prefix such as "HK."
                # -- verify against the data loader).
                source_table = 'hk_' + stock_code[3:]
                for month_col, (start_date, end_date) in self.month_ranges.items():
                    sql = """
                        SELECT AVG(close_price * float_share) as avg_close
                        FROM {}
                        WHERE stock_code = %s
                        AND trade_date BETWEEN %s AND %s
                        AND close_price IS NOT NULL
                        AND float_share IS NOT NULL
                    """.format(source_table)

                    result = db.execute_query(sql, (stock_code, start_date, end_date))
                    avg_close = result[0]['avg_close'] if result else None
                    # Scale by 1000, express in units of 1e8, keep 3 decimals.
                    # NOTE(review): the factor 1000 suggests float_share is
                    # stored in thousands -- confirm.
                    monthly_data[month_col] = (
                        round(float(avg_close) * 1000 / 100000000, 3) if avg_close else None
                    )

                # Collect every month column that actually produced a value.
                valid_ym_values = [
                    value for key, value in monthly_data.items()
                    if key.startswith('ym_') and value is not None
                ]

                # Mean across all months with data.
                if valid_ym_values:
                    average = sum(valid_ym_values) / len(valid_ym_values)
                    monthly_data['avg_all'] = average
                    self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
                else:
                    # Placeholder so the insert always has a value to bind.
                    monthly_data['avg_all'] = 0
                    self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")

                # Insert a new row or refresh the existing one.
                upsert_sql = f"""
                    INSERT INTO {target_table} (
                        stock_code, stock_name,
                        ym_2501, ym_2502, ym_2503, ym_2504,
                        ym_2505, ym_2506,ym_2507, ym_2508,
                        ym_2509,
                        avg_all
                    ) VALUES (
                        %(stock_code)s, %(stock_name)s,
                        %(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s,
                        %(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s,
                        %(ym_2509)s,
                        %(avg_all)s
                    )
                    ON DUPLICATE KEY UPDATE
                        stock_name = VALUES(stock_name),
                        ym_2501 = VALUES(ym_2501),
                        ym_2502 = VALUES(ym_2502),
                        ym_2503 = VALUES(ym_2503),
                        ym_2504 = VALUES(ym_2504),
                        ym_2505 = VALUES(ym_2505),
                        ym_2506 = VALUES(ym_2506),
                        ym_2507 = VALUES(ym_2507),
                        ym_2508 = VALUES(ym_2508),
                        ym_2509 = VALUES(ym_2509),
                        avg_all = VALUES(avg_all),
                        update_time = CURRENT_TIMESTAMP
                """
                db.execute_update(upsert_sql, monthly_data)
                return True
        except Exception as e:
            self.logger.error(f"计算和保存月度均值失败: {str(e)}")
            return False

    @staticmethod
    def safe_float(v) -> Optional[float]:
        """Convert ``v`` to float; return None for NA values, 'N/A' or unparsable input."""
        try:
            if pd.notna(v) and str(v).upper() != 'N/A':
                return float(v)
            return None
        except (ValueError, TypeError):
            return None

    @staticmethod
    def safe_int(v) -> Optional[int]:
        """Convert ``v`` to int; return None for NA values, 'N/A' or unparsable input."""
        try:
            if pd.notna(v) and str(v).upper() != 'N/A':
                return int(v)
            return None
        except (ValueError, TypeError):
            return None

    @staticmethod
    def safe_parse_date(date_str, date_format='%Y-%m-%d'):
        """
        Parse a date string without raising.

        :param date_str: The date string to parse.
        :param date_format: ``strptime`` format of the string.
        :return: A ``datetime`` object, or None when the input is empty,
            NA, blank or does not match the format.
        """
        if not date_str or pd.isna(date_str) or str(date_str).strip() == '':
            return None
        try:
            return datetime.strptime(str(date_str), date_format)
        except ValueError:
            return None

    def validate_market_data(self, dataset: list) -> list:
        """
        Filter a raw data set down to the usable records.

        A record is dropped when ``code`` or ``name`` is missing/empty, when
        the name ends with ``'R'``, or when checking it raises. A negative
        ``lot_size`` is reset to None in place and the record is kept.

        Args:
            dataset (list): Raw records (dicts).

        Returns:
            list: The records that passed validation.
        """
        validated_data = []
        for item in dataset:
            try:
                # Required fields.
                if not item.get('code') or not item.get('name'):
                    self.logger.warning("跳过无效数据: 缺少必要字段 code或name")
                    continue

                # Names ending in 'R' are excluded.
                if item.get('name')[-1] == 'R':
                    continue

                # Range check on the numeric field.
                if item.get('lot_size') is not None and item['lot_size'] < 0:
                    self.logger.warning(f"股票 {item['code']} 的lot_size为负值: {item['lot_size']}")
                    item['lot_size'] = None

                validated_data.append(item)
            except Exception as e:
                self.logger.warning(f"数据验证失败,跳过记录 {item.get('code')}: {str(e)}")
                continue

        return validated_data

    def get_market_data(self, market: "Market") -> List[str]:
        """
        Fetch the stock codes of one market from the Futu API.

        Args:
            market (Market): Market enum value, e.g. Market.SH, Market.SZ.

        Returns:
            List[str]: The stock codes, or an empty list on any failure.
        """
        quote_ctx = OpenQuoteContext(host='127.0.0.1', port=11111)
        try:
            ret, data = quote_ctx.get_stock_basicinfo(market, SecurityType.STOCK)
            if ret == RET_OK:
                # Pull the code column out as a plain list of strings.
                codes = data['code'].astype(str).tolist()
                self.logger.info(f"获取到 {market} 市场 {len(codes)} 个股票代码")
                return codes
            self.logger.error(f"获取股票代码失败: {data}")
            return []
        except Exception as e:
            self.logger.error(f"获取股票代码时发生异常: {str(e)}")
            return []
        finally:
            # Always release the quote connection.
            quote_ctx.close()

    def get_stock_codes(self) -> List[str]:
        """
        Return the stock codes stored in the ``stock_filter`` table.

        Rows without a code, without a name, or whose name ends with ``'R'``
        are skipped. Any exception is logged and yields an empty list.
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                sql = "SELECT DISTINCT stock_code,stock_name FROM stock_filter"
                results = db.execute_query(sql)
                codes = []
                for row in results:
                    if not row['stock_code'] or not row.get('stock_name', ''):
                        continue
                    if str(row['stock_name'])[-1] == 'R':
                        continue
                    codes.append(row['stock_code'])
                return codes
        except Exception as e:
            self.logger.error(f"获取股票代码失败: {str(e)}")
            return []

    @staticmethod
    def read_stock_codes_list(file_path='Reservedcode.txt'):
        """
        Read stock codes from a text file, one per line.

        Surrounding whitespace is stripped and blank lines are dropped. A
        missing or unreadable file prints a message and yields an empty list.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            print(f"文件 {file_path} 不存在")
            return []
        except Exception as e:
            print(f"读取文件失败: {str(e)}")
            return []

    def get_monthly_avg_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read every row of a monthly-average table.

        The column list is taken from INFORMATION_SCHEMA and the ``id`` and
        ``update_time`` columns are left out.

        Args:
            table_name: Source table name (interpolated into the SELECT, so
                it must be a trusted identifier).

        Returns:
            List[Dict]: The rows ordered by ``stock_code``; None when the
            table has no columns, has no rows, or an exception occurred.
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                # Discover the table layout.
                columns = db.execute_query("""
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                    ORDER BY ORDINAL_POSITION
                """, (self.db_config['database'], table_name))

                if not columns:
                    self.logger.error(f"表 {table_name} 不存在或没有列")
                    return None

                # Bookkeeping columns are not exported.
                field_names = [col['COLUMN_NAME'] for col in columns
                               if col['COLUMN_NAME'] not in ('id', 'update_time')]

                data = db.execute_query(f"""
                    SELECT {', '.join(field_names)}
                    FROM {table_name}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有数据")
                    return None

                return data

        except Exception as e:
            self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
            return None

    def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read ``stock_code``, ``stock_name`` and ``float_share`` from a table.

        Args:
            table_name: Source table name (interpolated into the SELECT, so
                it must be a trusted identifier).

        Returns:
            List[Dict]: The rows ordered by ``stock_code``; None when the
            query returns nothing or an exception occurred.
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                data = db.execute_query(f"""
                    SELECT stock_code, stock_name, float_share
                    FROM {table_name}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有流通股本数据")
                    return None

                return data

        except Exception as e:
            self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
            return None

    def export_to_csv(self, data: List[Dict], output_file: str) -> bool:
        """
        Write a list of row dicts to a CSV file (UTF-8 with BOM).

        The header row uses the labels from ``self.head_map`` and falls back
        to the raw key when no label is configured. The column order comes
        from the keys of the first row. Missing parent directories of
        ``output_file`` are created.

        Args:
            data: Rows to export.
            output_file: Path of the CSV file to write.

        Returns:
            bool: True on success; False when ``data`` is empty or writing
            failed (the error is logged).
        """
        if not data:
            self.logger.warning("没有数据可导出")
            return False

        try:
            # Column order follows the first row.
            field_names = list(data[0].keys())

            # export_data always targets a sub-directory, which may not
            # exist yet on a fresh checkout.
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=field_names)

                # Header row with the configured display labels.
                writer.writerow({col: self.head_map.get(col, col) for col in field_names})

                writer.writerows(data)

            self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
            return True

        except Exception as e:
            self.logger.error(f"导出到CSV失败: {str(e)}")
            return False

    def export_data(self, monthly_table: str, csv_file: Optional[str] = None) -> bool:
        """
        Load a monthly-average table and export it to ``data/<csv_file>``.

        Args:
            monthly_table: Name of the monthly-average table to read.
            csv_file: File name of the CSV inside ``data/`` (optional). When
                omitted nothing is written.

        Returns:
            bool: False when no data could be loaded or the export failed;
            True otherwise (including when ``csv_file`` is not given).
        """
        monthly_data = self.get_monthly_avg_data(monthly_table)
        if not monthly_data:
            self.logger.error("无法获取月度均价数据")
            return False

        if not csv_file:
            return True

        return self.export_to_csv(monthly_data, 'data/' + csv_file)
|
||
|
||
|