update 数据导出优化
This commit is contained in:
454
DataAnalysis/MarketDataCalculator_2024.py
Normal file
454
DataAnalysis/MarketDataCalculator_2024.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# """
|
||||
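# Utility toolkit
#
# -- Compute the monthly average float market cap:
#    the mean, over one month, of each trading day's float market cap
#    (that day's float share count * that day's share price)
# -- Extend with further helpers as needed ...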
|
||||
|
||||
# """
|
||||
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from futu import *
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Union, Tuple
|
||||
import csv
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from base import LogHelper, MySQLHelper, ConfigInfo
|
||||
|
||||
|
||||
class MarketDataCalculator:
    """
    Market-data calculation helper.

    Features:
    - Compute the monthly average float market cap: the mean, over one month,
      of the daily float market cap (float shares of the day * price of the day)
    - Data validation and conversion
    - Fetch market data from the Futu API
    - Read stock codes from the database
    - Create and manage the database table structure
    - Export data to CSV files
    """

    def __init__(self, db_config: dict, logger_name: str = 'Calculate'):
        """
        Initialise the calculator.

        Args:
            db_config: keyword arguments forwarded to ``MySQLHelper``.
            logger_name: name passed to ``LogHelper`` when building the logger.
        """
        self.db_config = db_config
        self.logger = LogHelper(logger_name=logger_name).setup()
        # Mapping of month column name -> (start_date, end_date), as consumed by
        # calculate_and_save_monthly_avg.
        self.month_ranges = ConfigInfo.MONTH_RANGES
        # Mapping of column name -> CSV header label, as consumed by export_to_csv.
        self.head_map = ConfigInfo.HEADER_MAP

    def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool:
        """
        Create (if missing) the table holding the monthly averages.

        The table has one row per stock with the columns ``ym_2401`` ..
        ``ym_2412`` (January to December 2024) plus ``avg_all``.

        Args:
            target_table: name of the table to create. It is interpolated
                directly into the SQL, so it must come from trusted code.

        Returns:
            bool: True when the statement ran, False when any error was raised
            (the error is logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                # NOTE(review): the table COMMENT below says "2024.10-2024.08",
                # which does not match the ym_2401..ym_2412 columns. It is kept
                # unchanged here because it is schema metadata - confirm and fix.
                create_sql = f"""
                CREATE TABLE IF NOT EXISTS {target_table} (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    stock_code VARCHAR(20) NOT NULL COMMENT '股票代码',
                    stock_name VARCHAR(50) COMMENT '股票名称',
                    ym_2401 DECIMAL(20, 5) COMMENT '2024年01月',
                    ym_2402 DECIMAL(20, 5) COMMENT '2024年02月',
                    ym_2403 DECIMAL(20, 5) COMMENT '2024年03月',
                    ym_2404 DECIMAL(20, 5) COMMENT '2024年04月',
                    ym_2405 DECIMAL(20, 5) COMMENT '2024年05月',
                    ym_2406 DECIMAL(20, 5) COMMENT '2024年06月',
                    ym_2407 DECIMAL(20, 5) COMMENT '2024年07月',
                    ym_2408 DECIMAL(20, 5) COMMENT '2024年08月',
                    ym_2409 DECIMAL(20, 5) COMMENT '2024年09月',
                    ym_2410 DECIMAL(20, 5) COMMENT '2024年10月',
                    ym_2411 DECIMAL(20, 5) COMMENT '2024年11月',
                    ym_2412 DECIMAL(20, 5) COMMENT '2024年12月',
                    avg_all DECIMAL(20, 5) COMMENT '月间均值',
                    update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
                    UNIQUE KEY uk_stock_code (stock_code)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='月均流通市值表(2024.10-2024.08)'
                """
                db.execute_update(create_sql)
                return True
        except Exception as e:
            self.logger.error(f"创建表失败: {str(e)}")
            return False

    def calculate_and_save_monthly_avg(self,
                                       stock_code: str = "code",
                                       target_table: str = "monthly_close_avg") -> bool:
        """
        Compute one stock's monthly average float market cap and upsert it.

        For every entry of ``self.month_ranges`` the average of
        ``close_price * float_share`` is read from the per-stock table
        ``hk_<code>``; the per-month values and their overall mean are then
        written to ``target_table`` (insert, or update on duplicate stock code).

        Args:
            stock_code: stock code to process; it must exist in ``stock_filter``.
            target_table: destination table name.

        Returns:
            bool: True on success; False when the table could not be created,
            the stock is not in ``stock_filter``, or any error was raised.
        """
        try:
            # Make sure the destination table exists first.
            if not self.create_monthly_avg_table(target_table):
                return False

            with MySQLHelper(**self.db_config) as db:
                # stock_filter is treated as the reference list of stocks.
                sql = """
                SELECT stock_code, stock_name
                FROM stock_filter
                WHERE stock_code = %s
                """
                # The parameters must be a real one-element tuple: "(x)" is just x.
                stock_info = db.execute_query(sql, (stock_code,))
                if not stock_info:
                    self.logger.warning(f"股票 {stock_code} 不在 stock_filter 表中,跳过")
                    return False

                stock_code = stock_info[0]['stock_code']
                stock_name = stock_info[0]['stock_name']
                monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}

                # The per-stock table name drops the first three characters of
                # the code (presumably a market prefix such as "HK." - confirm).
                source_table = 'hk_' + stock_code[3:]
                # The statement is identical for every month, so build it once.
                month_sql = """
                SELECT AVG(close_price * float_share) as avg_close
                FROM {}
                WHERE stock_code = %s
                AND trade_date BETWEEN %s AND %s
                AND close_price IS NOT NULL
                AND float_share IS NOT NULL
                """.format(source_table)

                for month_col, (start_date, end_date) in self.month_ranges.items():
                    result = db.execute_query(month_sql, (stock_code, start_date, end_date))
                    avg_close = result[0]['avg_close'] if result else None
                    # Scale by 1000 / 1e8 and round to 3 decimals. Presumably
                    # float_share is stored in thousands and the target unit is
                    # 1e8 (yi) - TODO confirm. A missing or zero average is
                    # stored as NULL.
                    monthly_data[month_col] = (
                        round(float(avg_close) * 1000 / 100000000, 3) if avg_close else None
                    )

                # Mean across the months that actually have a value.
                valid_ym_values = [
                    value for key, value in monthly_data.items()
                    if key.startswith('ym_') and value is not None
                ]
                if valid_ym_values:
                    average = sum(valid_ym_values) / len(valid_ym_values)
                    monthly_data['avg_all'] = average
                    self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
                else:
                    # The original design stores 0 rather than NULL when no
                    # month has data.
                    monthly_data['avg_all'] = 0
                    self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")

                # Insert the row, or refresh it when the stock already exists.
                upsert_sql = f"""
                INSERT INTO {target_table} (
                    stock_code, stock_name,
                    ym_2401, ym_2402, ym_2403, ym_2404,
                    ym_2405, ym_2406, ym_2407, ym_2408,
                    ym_2409, ym_2410, ym_2411, ym_2412,
                    avg_all
                ) VALUES (
                    %(stock_code)s, %(stock_name)s,
                    %(ym_2401)s, %(ym_2402)s, %(ym_2403)s, %(ym_2404)s,
                    %(ym_2405)s, %(ym_2406)s, %(ym_2407)s, %(ym_2408)s,
                    %(ym_2409)s, %(ym_2410)s, %(ym_2411)s, %(ym_2412)s,
                    %(avg_all)s
                )
                ON DUPLICATE KEY UPDATE
                    stock_name = VALUES(stock_name),
                    ym_2401 = VALUES(ym_2401),
                    ym_2402 = VALUES(ym_2402),
                    ym_2403 = VALUES(ym_2403),
                    ym_2404 = VALUES(ym_2404),
                    ym_2405 = VALUES(ym_2405),
                    ym_2406 = VALUES(ym_2406),
                    ym_2407 = VALUES(ym_2407),
                    ym_2408 = VALUES(ym_2408),
                    ym_2409 = VALUES(ym_2409),
                    ym_2410 = VALUES(ym_2410),
                    ym_2411 = VALUES(ym_2411),
                    ym_2412 = VALUES(ym_2412),
                    avg_all = VALUES(avg_all),
                    update_time = CURRENT_TIMESTAMP
                """
                db.execute_update(upsert_sql, monthly_data)
                return True
        except Exception as e:
            self.logger.error(f"计算和保存月度均值失败: {str(e)}")
            return False

    @staticmethod
    def safe_float(v) -> Optional[float]:
        """Convert ``v`` to float; return None for NA values, 'N/A' or bad input."""
        try:
            return float(v) if pd.notna(v) and str(v).upper() != 'N/A' else None
        except (ValueError, TypeError):
            return None

    @staticmethod
    def safe_int(v) -> Optional[int]:
        """Convert ``v`` to int; return None for NA values, 'N/A' or bad input."""
        try:
            return int(v) if pd.notna(v) and str(v).upper() != 'N/A' else None
        except (ValueError, TypeError):
            return None

    @staticmethod
    def safe_parse_date(date_str, date_format='%Y-%m-%d') -> Optional[datetime]:
        """
        Parse a date string without raising.

        Args:
            date_str: the value to parse.
            date_format: ``strptime`` format of the value.

        Returns:
            The parsed ``datetime``, or None when the value is empty, NA,
            whitespace-only, or does not match ``date_format``.
        """
        if not date_str or pd.isna(date_str) or str(date_str).strip() == '':
            return None
        try:
            return datetime.strptime(str(date_str), date_format)
        except ValueError:
            return None

    def validate_market_data(self, dataset: list) -> list:
        """
        Filter a raw data set down to the usable records.

        Records without ``code`` or ``name`` are skipped with a warning,
        records whose name ends with 'R' are skipped silently, and a negative
        ``lot_size`` is reset to None in place. A record that raises during
        validation is skipped with a warning.

        Args:
            dataset (list): raw records (dict-like).

        Returns:
            list: the records that passed validation.
        """
        validated_data = []
        for item in dataset:
            try:
                name = item.get('name')
                # Required-field check.
                if not item.get('code') or not name:
                    self.logger.warning(f"跳过无效数据: 缺少必要字段 code或name")
                    continue

                # Drop stocks whose name ends with 'R'.
                if name[-1] == 'R':
                    continue

                # Range check on the lot size.
                if item.get('lot_size') is not None and item['lot_size'] < 0:
                    self.logger.warning(f"股票 {item['code']} 的lot_size为负值: {item['lot_size']}")
                    item['lot_size'] = None

                validated_data.append(item)
            except Exception as e:
                # Build the label defensively: the record may not be a mapping
                # at all, and the handler itself must not raise.
                record_id = item.get('code') if hasattr(item, 'get') else item
                self.logger.warning(f"数据验证失败,跳过记录 {record_id}: {str(e)}")
                continue

        return validated_data

    def get_market_data(self, market: "Market",
                        host: str = '127.0.0.1', port: int = 11111) -> List[str]:
        """
        Fetch the stock codes of one market from the Futu API.

        Args:
            market (Market): market enum value, e.g. Market.SH, Market.SZ.
            host: host passed to ``OpenQuoteContext``.
            port: port passed to ``OpenQuoteContext``.

        Returns:
            List[str]: the stock codes, or an empty list when the request fails
            or raises (the failure is logged).
        """
        quote_ctx = None
        try:
            # Opened inside the try so a connection failure is logged and
            # reported as an empty list like every other failure.
            quote_ctx = OpenQuoteContext(host=host, port=port)
            ret, data = quote_ctx.get_stock_basicinfo(market, SecurityType.STOCK)
            if ret != RET_OK:
                self.logger.error(f"获取股票代码失败: {data}")
                return []

            # Extract the 'code' column as a list of strings.
            codes = data['code'].astype(str).tolist()
            self.logger.info(f"获取到 {market} 市场 {len(codes)} 个股票代码")
            return codes
        except Exception as e:
            self.logger.error(f"获取股票代码时发生异常: {str(e)}")
            return []
        finally:
            if quote_ctx is not None:
                quote_ctx.close()

    def get_stock_codes(self) -> List[str]:
        """
        Read the stock codes from the ``stock_filter`` table.

        Rows with an empty code, an empty name, or a name ending with 'R' are
        excluded. Returns an empty list when the query fails (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                sql = "SELECT DISTINCT stock_code,stock_name FROM stock_filter"
                results = db.execute_query(sql)
                return [
                    row['stock_code']
                    for row in results
                    if row['stock_code']
                    and row.get('stock_name')
                    and not str(row['stock_name']).endswith('R')
                ]
        except Exception as e:
            self.logger.error(f"获取股票代码失败: {str(e)}")
            return []

    @staticmethod
    def read_stock_codes_list(file_path='Reservedcode.txt'):
        """
        Read stock codes from a text file, one per line.

        Surrounding whitespace is stripped and blank lines are dropped.
        Returns an empty list (after printing a message) when the file is
        missing or cannot be read.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            print(f"文件 {file_path} 不存在")
            return []
        except Exception as e:
            print(f"读取文件失败: {str(e)}")
            return []

    def get_monthly_avg_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read every row of a monthly-average table.

        The column list is taken from INFORMATION_SCHEMA (in ordinal order),
        excluding ``id`` and ``update_time``; rows are ordered by stock code.

        Args:
            table_name: source table name (interpolated into the SQL, so it
                must come from trusted code).

        Returns:
            List[Dict]: the rows, or None when the table is missing, has no
            exportable columns, has no rows, or the query fails (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                # Discover the table's columns.
                columns = db.execute_query("""
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                    ORDER BY ORDINAL_POSITION
                """, (self.db_config['database'], table_name))

                if not columns:
                    self.logger.error(f"表 {table_name} 不存在或没有列")
                    return None

                # Bookkeeping columns are not exported.
                field_names = [col['COLUMN_NAME'] for col in columns
                               if col['COLUMN_NAME'] not in ('id', 'update_time')]
                if not field_names:
                    self.logger.error(f"表 {table_name} 没有可导出的列")
                    return None

                # Backtick-quote the names so reserved words cannot break the
                # statement; the result keys are still the plain column names.
                select_list = ', '.join(f"`{name}`" for name in field_names)
                data = db.execute_query(f"""
                    SELECT {select_list}
                    FROM {table_name}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有数据")
                    return None

                return data

        except Exception as e:
            self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
            return None

    def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read stock code, stock name and float share count from a table.

        Args:
            table_name: source table name (interpolated into the SQL, so it
                must come from trusted code).

        Returns:
            List[Dict]: the rows ordered by stock code, or None when the table
            has no rows or the query fails (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                data = db.execute_query(f"""
                    SELECT stock_code, stock_name, float_share
                    FROM {table_name}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有流通股本数据")
                    return None

                return data

        except Exception as e:
            self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
            return None

    def export_to_csv(self, data: List[Dict], output_file: str) -> bool:
        """
        Write a list of row dicts to a CSV file.

        The columns are the keys of the first row. The header row uses the
        labels from ``self.head_map`` where one exists and the raw key
        otherwise. The file is written as UTF-8 with a BOM, and missing parent
        directories are created.

        Args:
            data: rows to export.
            output_file: path of the CSV file to write.

        Returns:
            bool: True on success; False when ``data`` is empty or writing
            fails (logged).
        """
        if not data:
            self.logger.warning("没有数据可导出")
            return False

        try:
            # Columns come from the first row's keys.
            field_names = list(data[0].keys())

            # Make sure the target directory exists before opening the file.
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=field_names)
                # Header row with the display labels.
                writer.writerow({col: self.head_map.get(col, col) for col in field_names})
                writer.writerows(data)

            self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
            return True

        except Exception as e:
            self.logger.error(f"导出到CSV失败: {str(e)}")
            return False

    def export_data(self, monthly_table: str, csv_file: str = None) -> bool:
        """
        Export a monthly-average table to ``data/<csv_file>``.

        Args:
            monthly_table: name of the monthly-average table to read.
            csv_file: output file name, placed under ``data/``. When omitted
                the data is only fetched and nothing is written.

        Returns:
            bool: False when the data cannot be read or the export fails;
            True otherwise (including when ``csv_file`` is omitted).
        """
        monthly_data = self.get_monthly_avg_data(monthly_table)
        if not monthly_data:
            self.logger.error("无法获取月度均价数据")
            return False

        if not csv_file:
            return True

        return self.export_to_csv(monthly_data, 'data/' + csv_file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user