Files
HKDataManagment/DataAnalysis/MarketDataCalculator.py
2025-09-12 10:25:31 +08:00

450 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# """
# 工具包函数
# —— 计算月均流通市值:
# 一个月内,每天流通市值的均值(当日流通股数量 * 当日股价)
# —— 根据需要补充 ...
# """
import pandas as pd
from datetime import datetime
from futu import *
from tqdm import tqdm
from pathlib import Path
from typing import Optional, List, Dict, Union, Tuple
import csv
from typing import List, Dict, Optional
from datetime import datetime
from base import LogHelper, MySQLHelper, ConfigInfo
class MarketDataCalculator:
    """
    Market-data calculation helpers (市场数据计算工具类).

    Features:
    - Monthly average float market value: for each configured month, the
      average over trading days of ``close_price * float_share`` read from a
      per-stock table, upserted into a summary table.
    - Lenient value conversion (``safe_float`` / ``safe_int`` /
      ``safe_parse_date``) and validation of stock-list records.
    - Fetching stock codes from the futu quote API or the ``stock_filter`` table.
    - Creating the summary table and exporting table contents to CSV.
    """

    # Month columns of the summary table, in table order. Both the CREATE TABLE
    # DDL and the upsert statement are generated from this tuple so they cannot
    # drift apart.
    # NOTE(review): presumably these match the keys of ConfigInfo.MONTH_RANGES;
    # month keys produced there but missing here are not persisted — confirm.
    _MONTH_COLUMNS = (
        'ym_2501', 'ym_2502', 'ym_2503', 'ym_2504', 'ym_2505',
        'ym_2506', 'ym_2507', 'ym_2508', 'ym_2509',
    )

    def __init__(self, db_config: dict, logger_name: str = 'Calculate'):
        """
        Initialise the calculator.

        Args:
            db_config: Keyword arguments forwarded to ``MySQLHelper(**db_config)``.
                ``get_monthly_avg_data`` also reads its ``'database'`` key.
            logger_name: Name passed to ``LogHelper``.
        """
        self.db_config = db_config
        self.logger = LogHelper(logger_name=logger_name).setup()
        # Mapping: month column -> (start_date, end_date), iterated by
        # calculate_and_save_monthly_avg.
        self.month_ranges = ConfigInfo.MONTH_RANGES
        # Mapping: column name -> header text written as the CSV header row.
        self.head_map = ConfigInfo.HEADER_MAP

    @classmethod
    def _build_create_table_sql(cls, target_table: str) -> str:
        """Return the ``CREATE TABLE IF NOT EXISTS`` DDL for the summary table."""
        first, last = cls._MONTH_COLUMNS[0], cls._MONTH_COLUMNS[-1]
        # 'ym_YYMM' -> '20YY年MM月', the per-column comment format.
        month_defs = ''.join(
            f"    {col} DECIMAL(20, 5) COMMENT '20{col[3:5]}年{col[5:7]}月',\n"
            for col in cls._MONTH_COLUMNS
        )
        # Table comment derived from the actual column range (it previously
        # claimed 2024.10-2025.08, which did not match the columns).
        period = f"20{first[3:5]}.{first[5:7]}-20{last[3:5]}.{last[5:7]}"
        return (
            f"CREATE TABLE IF NOT EXISTS {target_table} (\n"
            "    id INT AUTO_INCREMENT PRIMARY KEY,\n"
            "    stock_code VARCHAR(20) NOT NULL COMMENT '股票代码',\n"
            "    stock_name VARCHAR(50) COMMENT '股票名称',\n"
            f"{month_defs}"
            "    avg_all DECIMAL(20, 5) COMMENT '月间均值',\n"
            "    update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
            " ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',\n"
            "    UNIQUE KEY uk_stock_code (stock_code)\n"
            f") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='月均流通市值表({period})'"
        )

    @classmethod
    def _build_upsert_sql(cls, target_table: str) -> str:
        """
        Return the ``INSERT ... ON DUPLICATE KEY UPDATE`` statement for the summary table.

        Uses named ``%(col)s`` placeholders for stock_code, stock_name, every
        column in ``_MONTH_COLUMNS`` and avg_all. On a duplicate stock_code every
        other column is overwritten and update_time is refreshed.
        """
        columns = ('stock_code', 'stock_name') + cls._MONTH_COLUMNS + ('avg_all',)
        column_list = ', '.join(columns)
        placeholders = ', '.join(f'%({col})s' for col in columns)
        # NOTE: VALUES(col) in the UPDATE clause is deprecated since MySQL 8.0.20
        # but still supported; kept to match the original statement.
        assignments = ''.join(f'    {col} = VALUES({col}),\n' for col in columns[1:])
        return (
            f"INSERT INTO {target_table} ({column_list})\n"
            f"VALUES ({placeholders})\n"
            "ON DUPLICATE KEY UPDATE\n"
            f"{assignments}"
            "    update_time = CURRENT_TIMESTAMP"
        )

    @staticmethod
    def _scale_market_value(avg_close) -> Optional[float]:
        """
        Convert a raw ``AVG(close_price * float_share)`` into the stored unit.

        The value is multiplied by 1000, divided by 1e8 (亿) and rounded to
        3 decimal places. ``None`` and any other falsy value (including an
        average of exactly 0) map to ``None``.
        NOTE(review): the x1000 presumably means float_share is stored in
        thousands of shares — confirm against the source tables.
        """
        if not avg_close:
            return None
        return round(float(avg_close) * 1000 / 100000000, 3)

    def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool:
        """
        Create the monthly-average summary table if it does not exist yet.

        The table holds one row per ``stock_code`` (unique key). It has one
        DECIMAL(20, 5) column per entry of ``_MONTH_COLUMNS`` (2025-01 .. 2025-09),
        the cross-month mean ``avg_all``, and an auto-maintained ``update_time``.
        An existing table is left untouched.

        Args:
            target_table: Table name. It is interpolated into the SQL as-is, so it
                must be a trusted identifier.

        Returns:
            bool: True if the statement succeeded, False if it raised (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                db.execute_update(self._build_create_table_sql(target_table))
            return True
        except Exception as e:
            self.logger.error(f"创建表失败: {str(e)}")
            return False

    def calculate_and_save_monthly_avg(self,
                                       stock_code: str = "code",
                                       target_table: str = "monthly_close_avg") -> bool:
        """
        Compute one stock's monthly average float market value and upsert it.

        For each ``month_col -> (start_date, end_date)`` in ``self.month_ranges``:
        - Read ``AVG(close_price * float_share)`` over that date range from the
          per-stock table ``'hk_' + stock_code[3:]``.
        - Scale the result with ``_scale_market_value``.

        ``avg_all`` is the mean of the non-empty months, or 0 when every month
        is empty. The row is then upserted into ``target_table``.

        Args:
            stock_code: Code looked up in ``stock_filter``. Its first three
                characters are dropped to form the source table name
                (presumably a 'HK.' prefix — confirm).
            target_table: Summary table to create if needed and upsert into.

        Returns:
            bool: True on success. False if the table could not be created, the
            stock is not in ``stock_filter``, or a database error occurred (logged).
        """
        try:
            if not self.create_monthly_avg_table(target_table):
                return False
            with MySQLHelper(**self.db_config) as db:
                # stock_filter is treated as the reference list of stocks.
                # The parameters must be a 1-tuple: `(stock_code)` is just a str.
                stock_info = db.execute_query(
                    "SELECT stock_code, stock_name FROM stock_filter WHERE stock_code = %s",
                    (stock_code,),
                )
                if not stock_info:
                    self.logger.warning(f"股票 {stock_code} 不在 stock_filter 表中,跳过")
                    return False
                stock_code = stock_info[0]['stock_code']
                stock_name = stock_info[0]['stock_name']
                monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}

                # Table names cannot be bound as parameters, so the per-stock
                # table is interpolated. The text is the same for every month,
                # so it is built once, outside the loop.
                source_table = 'hk_' + stock_code[3:]
                month_sql = f"""
                    SELECT AVG(close_price * float_share) as avg_close
                    FROM {source_table}
                    WHERE stock_code = %s
                      AND trade_date BETWEEN %s AND %s
                      AND close_price IS NOT NULL
                      AND float_share IS NOT NULL
                """
                for month_col, (start_date, end_date) in self.month_ranges.items():
                    result = db.execute_query(month_sql, (stock_code, start_date, end_date))
                    avg_close = result[0]['avg_close'] if result else None
                    monthly_data[month_col] = self._scale_market_value(avg_close)

                # Mean over every non-empty 'ym_' month produced above.
                valid_ym_values = [
                    value for key, value in monthly_data.items()
                    if key.startswith('ym_') and value is not None
                ]
                if valid_ym_values:
                    average = sum(valid_ym_values) / len(valid_ym_values)
                    monthly_data['avg_all'] = average
                    self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
                else:
                    # 0 rather than NULL when no month has data.
                    monthly_data['avg_all'] = 0
                    self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")

                # Every upsert placeholder needs a key. Months not covered by
                # month_ranges are stored as NULL instead of failing the upsert.
                for col in self._MONTH_COLUMNS:
                    monthly_data.setdefault(col, None)
                db.execute_update(self._build_upsert_sql(target_table), monthly_data)
            return True
        except Exception as e:
            self.logger.error(f"计算和保存月度均值失败: {str(e)}")
            return False

    @staticmethod
    def safe_float(v) -> Optional[float]:
        """Convert ``v`` to float, returning None for NA/None, 'N/A' (any case) or unconvertible input."""
        try:
            return float(v) if pd.notna(v) and str(v).upper() != 'N/A' else None
        except (ValueError, TypeError, OverflowError):
            # OverflowError: e.g. float(10**400).
            return None

    @staticmethod
    def safe_int(v) -> Optional[int]:
        """Convert ``v`` to int, returning None for NA/None, 'N/A' (any case) or unconvertible input."""
        try:
            return int(v) if pd.notna(v) and str(v).upper() != 'N/A' else None
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float('inf')) — pd.notna(inf) is True.
            return None

    @staticmethod
    def safe_parse_date(date_str, date_format: str = '%Y-%m-%d') -> Optional[datetime]:
        """
        Parse ``date_str`` with ``date_format``.

        Args:
            date_str: Value to parse. It is converted with ``str()`` and stripped
                of surrounding whitespace first.
            date_format: ``datetime.strptime`` format string.

        Returns:
            The parsed ``datetime``. Returns None for empty/NA input or text
            that does not match the format.
        """
        if not date_str or pd.isna(date_str):
            return None
        text = str(date_str).strip()
        if not text:
            return None
        try:
            return datetime.strptime(text, date_format)
        except ValueError:
            return None

    def validate_market_data(self, dataset: list) -> list:
        """
        Filter stock-list records.

        Records are dropped, in order, when:
        - ``code`` or ``name`` is missing or empty (logged);
        - ``name`` ends with 'R';
        - checking them raises (logged).

        A negative ``lot_size`` is logged and replaced with None in place. Note
        that the caller's dict is mutated.

        Args:
            dataset: List of dict-like records.

        Returns:
            list: The records that passed, in input order.
        """
        validated_data = []
        for item in dataset:
            try:
                code, name = item.get('code'), item.get('name')
                if not code or not name:
                    self.logger.warning("跳过无效数据: 缺少必要字段 code或name")
                    continue
                # Same 'R'-suffix rule as get_stock_codes.
                # NOTE(review): presumably these names mark a separate trading
                # counter — confirm the intent.
                if name.endswith('R'):
                    continue
                lot_size = item.get('lot_size')
                if lot_size is not None and lot_size < 0:
                    self.logger.warning(f"股票 {code} 的lot_size为负值: {lot_size}")
                    item['lot_size'] = None
                validated_data.append(item)
            except Exception as e:
                self.logger.warning(f"数据验证失败,跳过记录 {item.get('code')}: {str(e)}")
        return validated_data

    def get_market_data(self, market: "Market", host: str = '127.0.0.1',
                        port: int = 11111) -> List[str]:
        """
        Fetch the stock codes of ``market`` through futu ``OpenQuoteContext``.

        Args:
            market: futu ``Market`` value (e.g. ``Market.HK``).
            host: Quote-context host (default keeps the previous hard-coded value).
            port: Quote-context port (default keeps the previous hard-coded value).

        Returns:
            List[str]: Codes from the ``code`` column of
            ``get_stock_basicinfo(market, SecurityType.STOCK)``. Returns an empty
            list on an API error or exception (logged). The context is always closed.
        """
        quote_ctx = OpenQuoteContext(host=host, port=port)
        try:
            ret, data = quote_ctx.get_stock_basicinfo(market, SecurityType.STOCK)
            if ret != RET_OK:
                # On failure `data` carries the error message.
                self.logger.error(f"获取股票代码失败: {data}")
                return []
            codes = data['code'].astype(str).tolist()
            self.logger.info(f"获取到 {market} 市场 {len(codes)} 个股票代码")
            return codes
        except Exception as e:
            self.logger.error(f"获取股票代码时发生异常: {str(e)}")
            return []
        finally:
            quote_ctx.close()

    def get_stock_codes(self) -> List[str]:
        """
        Return stock codes from the ``stock_filter`` table.

        Rows are skipped when the code is empty, the name is empty, or the name
        ends with 'R' (the same rule as ``validate_market_data``).

        Returns:
            List[str]: Codes in query order. Returns an empty list on error (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                rows = db.execute_query("SELECT DISTINCT stock_code,stock_name FROM stock_filter")
                return [
                    row['stock_code']
                    for row in rows
                    if row['stock_code']
                    and row.get('stock_name')
                    and not str(row['stock_name']).endswith('R')
                ]
        except Exception as e:
            self.logger.error(f"获取股票代码失败: {str(e)}")
            return []

    @staticmethod
    def read_stock_codes_list(file_path='Reservedcode.txt'):
        """
        Read one stock code per line from a UTF-8 text file.

        Each line is stripped, and blank lines are skipped. A missing or
        unreadable file prints a message and yields an empty list.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            print(f"文件 {file_path} 不存在")
            return []
        except Exception as e:
            print(f"读取文件失败: {str(e)}")
            return []

    def get_monthly_avg_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read every row of ``table_name``, excluding the ``id`` and ``update_time`` columns.

        Column names come from INFORMATION_SCHEMA, using
        ``self.db_config['database']`` as the schema. They are returned in
        ordinal order, and rows are sorted by ``stock_code``.

        Args:
            table_name: Table to read. It is interpolated as-is, so it must be trusted.

        Returns:
            List[Dict]: The rows. Returns None if the table has no columns or no
            rows, or on error (all logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                columns = db.execute_query("""
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                    ORDER BY ORDINAL_POSITION
                """, (self.db_config['database'], table_name))
                if not columns:
                    self.logger.error(f"{table_name} 不存在或没有列")
                    return None
                field_names = [col['COLUMN_NAME'] for col in columns
                               if col['COLUMN_NAME'] not in ('id', 'update_time')]
                # Backtick-quote the names (doubling any embedded backtick) so
                # columns named after reserved words do not break the SELECT.
                select_list = ', '.join(
                    '`' + name.replace('`', '``') + '`' for name in field_names
                )
                data = db.execute_query(f"""
                    SELECT {select_list}
                    FROM {table_name}
                    ORDER BY stock_code
                """)
                if not data:
                    self.logger.error(f"{table_name} 中没有数据")
                    return None
                return data
        except Exception as e:
            self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
            return None

    def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read ``stock_code``, ``stock_name`` and ``float_share`` from ``table_name``.

        Args:
            table_name: Source table. It is interpolated as-is, so it must be trusted.

        Returns:
            List[Dict]: Rows ordered by ``stock_code``. Returns None if there are
            no rows or on error (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                data = db.execute_query(f"""
                    SELECT stock_code, stock_name, float_share
                    FROM {table_name}
                    ORDER BY stock_code
                """)
                if not data:
                    self.logger.error(f"{table_name} 中没有流通股本数据")
                    return None
                return data
        except Exception as e:
            self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
            return None

    def export_to_csv(self, data: List[Dict], output_file: str) -> bool:
        """
        Write ``data`` to ``output_file`` as UTF-8-with-BOM CSV.

        The columns are the keys of the first record. The header row shows
        ``self.head_map[col]`` when present, otherwise the column name. Missing
        parent directories are created.

        Args:
            data: Records to write.
            output_file: Destination path.

        Returns:
            bool: True on success. False for empty ``data`` or on error (logged).
        """
        if not data:
            self.logger.warning("没有数据可导出")
            return False
        try:
            field_names = list(data[0].keys())
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=field_names)
                writer.writerow({col: self.head_map.get(col, col) for col in field_names})
                writer.writerows(data)
            self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
            return True
        except Exception as e:
            self.logger.error(f"导出到CSV失败: {str(e)}")
            return False

    def export_data(self, monthly_table: str, csv_file: Optional[str] = None) -> bool:
        """
        Export ``monthly_table`` to ``data/<csv_file>``.

        Args:
            monthly_table: Table read with ``get_monthly_avg_data``.
            csv_file: File name, relative to the ``data`` directory. When falsy,
                nothing is written and the method only checks that the table has data.

        Returns:
            bool: False if the table could not be read or the export failed,
            True otherwise.
        """
        monthly_data = self.get_monthly_avg_data(monthly_table)
        if not monthly_data:
            self.logger.error("无法获取月度均价数据")
            return False
        if not csv_file:
            return True
        return self.export_to_csv(monthly_data, str(Path('data') / csv_file))