"""
Toolkit for market-data calculations.

Main feature: the monthly average free-float market value of each stock,
i.e. the mean, over the trading days of a month, of
``close_price * float_share`` (daily float shares times the daily price).
All functionality is exposed through :class:`MarketDataCalculator`.
"""
# Star import first so none of the explicit imports below can be shadowed by
# a same-named export from futu. Market, OpenQuoteContext, SecurityType and
# RET_OK come from here.
from futu import *  # noqa: F401,F403

import csv
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
from tqdm import tqdm

# `from base import X` only yields the class when base/__init__.py re-exports
# it; otherwise it binds the *submodule* base.X. It therefore runs before the
# explicit class imports, which must be the final bindings of these names.
from base import LogHelper, MySQLHelper
from base.LogHelper import LogHelper
from base.MySQLHelper import MySQLHelper


class MarketDataCalculator:
    """
    Utility class for market-data calculations.

    Features:
    - Monthly average free-float market value: for each month, the mean of
      the daily ``close_price * float_share`` values.
    - Data validation and safe type-conversion helpers.
    - Fetching stock codes from the Futu API.
    - Fetching stock codes from the database.
    - Creating the target table schema.
    - Exporting results to CSV.
    """

    # Column name -> Chinese CSV header written by export_to_csv.
    HEADER_MAP = {
        'stock_code': '股票代码',
        'stock_name': '股票名称',
        'ym_2410': '2024年10月',
        'ym_2411': '2024年11月',
        'ym_2412': '2024年12月',
        'ym_2501': '2025年01月',
        'ym_2502': '2025年02月',
        'ym_2503': '2025年03月',
        'ym_2504': '2025年04月',
        'ym_2505': '2025年05月',
        'ym_2506': '2025年06月',
        'ym_2507': '2025年07月',
        'ym_2508': '2025年08月',
        'avg_all': '月度均值'
    }

    # Month column -> (first day, last day), both inclusive ('YYYY-MM-DD').
    # The CREATE TABLE and upsert SQL are generated from these keys. Note that
    # CREATE TABLE IF NOT EXISTS will not add new columns to an existing table.
    MONTH_RANGES = {
        'ym_2410': ('2024-10-01', '2024-10-31'),
        'ym_2411': ('2024-11-01', '2024-11-30'),
        'ym_2412': ('2024-12-01', '2024-12-31'),
        'ym_2501': ('2025-01-01', '2025-01-31'),
        'ym_2502': ('2025-02-01', '2025-02-28'),
        'ym_2503': ('2025-03-01', '2025-03-31'),
        'ym_2504': ('2025-04-01', '2025-04-30'),
        'ym_2505': ('2025-05-01', '2025-05-31'),
        'ym_2506': ('2025-06-01', '2025-06-30'),
        'ym_2507': ('2025-07-01', '2025-07-31'),
        'ym_2508': ('2025-08-01', '2025-08-31')
    }

    # float_share is stored in units of 1,000 shares (per the original
    # author's note), so each monthly average is multiplied by this factor.
    FLOAT_SHARE_MULTIPLIER = 1000

    def __init__(self, db_config: dict, logger_name: str = 'Calculate'):
        """
        Initialize the calculator.

        Args:
            db_config: Keyword arguments passed to ``MySQLHelper`` (must
                include 'database' for get_monthly_avg_data).
            logger_name: Name of the logger created via ``LogHelper``.
        """
        self.db_config = db_config
        self.logger = LogHelper(logger_name=logger_name).setup()

    def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool:
        """
        Create (if missing) the table holding per-stock monthly averages.

        The table has one DECIMAL(20, 5) column per MONTH_RANGES key, plus
        ``avg_all``. ``stock_code`` is UNIQUE so rows can be upserted.

        Args:
            target_table: Table name. It is interpolated into the SQL, so it
                must be a trusted identifier.

        Returns:
            bool: True if the statement succeeded, False otherwise (the error
            is logged).
        """
        column_defs = [
            "id INT AUTO_INCREMENT PRIMARY KEY",
            "stock_code VARCHAR(20) NOT NULL COMMENT '股票代码'",
            "stock_name VARCHAR(50) COMMENT '股票名称'",
            *(f"{col} DECIMAL(20, 5) COMMENT '{self.HEADER_MAP.get(col, col)}'"
              for col in self.MONTH_RANGES),
            "avg_all DECIMAL(20, 5) COMMENT '月间均值'",
            "update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP "
            "ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间'",
            "UNIQUE KEY uk_stock_code (stock_code)",
        ]
        create_sql = (
            f"CREATE TABLE IF NOT EXISTS {target_table} (\n    "
            + ",\n    ".join(column_defs)
            + "\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 "
            "COMMENT='月均流通市值表(2024.10-2025.08)'"
        )
        try:
            with MySQLHelper(**self.db_config) as db:
                db.execute_update(create_sql)
                self.logger.info(f"创建/确认表 {target_table} 结构成功")
                return True
        except Exception as e:
            self.logger.error(f"创建表失败: {str(e)}")
            return False

    def calculate_and_save_monthly_avg(self,
                                       source_table: str = "stock_quotes",
                                       target_table: str = "monthly_close_avg") -> bool:
        """
        Compute and upsert the monthly average free-float market value.

        For every (stock_code, stock_name) pair in ``source_table`` and every
        month in MONTH_RANGES, this averages ``close_price * float_share``
        over rows whose ``trade_date`` lies in the month. Rows with NULL
        price or float_share are ignored. The average is scaled by
        FLOAT_SHARE_MULTIPLIER and stored in the month's ``ym_*`` column.
        ``avg_all`` is the mean of the months that have data, or NULL if none
        do.

        Args:
            source_table: Daily quote table (stock_code, stock_name,
                trade_date, close_price, float_share).
            target_table: Destination table; created if it does not exist.

        Returns:
            bool: True if all stocks were processed. False if the table could
            not be created, no stocks were found, or any query failed; the
            error is logged and processing stops.
        """
        month_cols = list(self.MONTH_RANGES)
        insert_cols = ['stock_code', 'stock_name', *month_cols, 'avg_all']
        update_cols = ['stock_name', *month_cols, 'avg_all']
        placeholders = ", ".join(f"%({col})s" for col in insert_cols)
        # Built once: the statement depends only on the column set.
        upsert_sql = (
            f"INSERT INTO {target_table} ({', '.join(insert_cols)}) "
            f"VALUES ({placeholders}) "
            "ON DUPLICATE KEY UPDATE "
            + ", ".join(f"{col} = VALUES({col})" for col in update_cols)
            + ", update_time = CURRENT_TIMESTAMP"
        )
        # NOTE(review): BETWEEN on the inclusive end date assumes trade_date
        # is a DATE; a DATETIME column would miss the last day after 00:00.
        month_sql = f"""
            SELECT AVG(close_price * float_share) AS avg_close
            FROM {source_table}
            WHERE stock_code = %s
              AND trade_date BETWEEN %s AND %s
              AND close_price IS NOT NULL
              AND float_share IS NOT NULL
        """

        try:
            if not self.create_monthly_avg_table(target_table):
                return False

            with MySQLHelper(**self.db_config) as db:
                stock_info = db.execute_query(
                    f"SELECT DISTINCT stock_code, stock_name FROM {source_table}"
                )
                if not stock_info:
                    self.logger.error("没有获取到股票基本信息")
                    return False

                for stock in stock_info:
                    stock_code = stock['stock_code']
                    row = {'stock_code': stock_code, 'stock_name': stock['stock_name']}

                    for month_col, (start_date, end_date) in self.MONTH_RANGES.items():
                        result = db.execute_query(month_sql, (stock_code, start_date, end_date))
                        avg_value = result[0]['avg_close'] if result else None
                        # AVG() yields NULL when no rows match. Compare with None
                        # explicitly so a genuine 0 average is not discarded.
                        row[month_col] = (
                            float(avg_value) * self.FLOAT_SHARE_MULTIPLIER
                            if avg_value is not None else None
                        )

                    valid_values = [row[col] for col in month_cols if row[col] is not None]
                    if valid_values:
                        average = sum(valid_values) / len(valid_values)
                        row['avg_all'] = average
                        self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
                    else:
                        # The upsert binds %(avg_all)s, so the key must exist.
                        # Leaving it unset raised KeyError and aborted the run.
                        row['avg_all'] = None
                        self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")

                    db.execute_update(upsert_sql, row)

                self.logger.info("月度均值计算和保存完成")
                return True
        except Exception as e:
            self.logger.error(f"计算和保存月度均值失败: {str(e)}")
            return False

    @staticmethod
    def safe_float(v) -> Optional[float]:
        """Convert ``v`` to float, returning None for NaN/None, 'N/A' or unconvertible input."""
        try:
            return float(v) if pd.notna(v) and str(v).upper() != 'N/A' else None
        except (ValueError, TypeError):
            return None

    @staticmethod
    def safe_int(v) -> Optional[int]:
        """Convert ``v`` to int, returning None for NaN/None, 'N/A' or unconvertible input."""
        try:
            return int(v) if pd.notna(v) and str(v).upper() != 'N/A' else None
        except (ValueError, TypeError):
            return None

    @staticmethod
    def safe_parse_date(date_str, date_format='%Y-%m-%d') -> Optional[datetime]:
        """
        Parse a date string without raising.

        :param date_str: Value to parse (converted with ``str()`` first).
        :param date_format: ``strptime`` format.
        :return: The parsed (naive) datetime, or None for empty/NaN input or
            input that does not match ``date_format``.
        """
        if not date_str or pd.isna(date_str) or str(date_str).strip() == '':
            return None
        try:
            return datetime.strptime(str(date_str), date_format)
        except ValueError:
            return None

    def validate_market_data(self, dataset: list) -> list:
        """
        Keep only the records that pass basic validation.

        A record is dropped when it lacks ``code`` or ``name``, when its name
        ends in 'R' (presumably the RMB-counter listings; confirm), or when
        inspecting it raises. A negative ``lot_size`` is logged and reset to
        None instead of dropping the record. Records are modified in place.

        Args:
            dataset: Iterable of dict-like records.

        Returns:
            list: The surviving records, in their original order.
        """
        validated_data = []
        for item in dataset:
            try:
                if not item.get('code') or not item.get('name'):
                    self.logger.warning("跳过无效数据: 缺少必要字段 code或name")
                    continue
                if item.get('name')[-1] == 'R':
                    continue
                if item.get('lot_size') is not None and item['lot_size'] < 0:
                    self.logger.warning(f"股票 {item['code']} 的lot_size为负值: {item['lot_size']}")
                    item['lot_size'] = None
                validated_data.append(item)
            except Exception as e:
                self.logger.warning(f"数据验证失败,跳过记录 {item.get('code')}: {str(e)}")
        return validated_data

    def get_market_data(self, market: Market, *,
                        host: str = '127.0.0.1', port: int = 11111) -> List[str]:
        """
        Fetch the stock codes of one market from the Futu OpenD API.

        Args:
            market: Market enum value, e.g. ``Market.SH``, ``Market.SZ``.
            host: OpenD host (defaults to the previously hard-coded value).
            port: OpenD port (defaults to the previously hard-coded value).

        Returns:
            List[str]: Stock codes; an empty list on API error or exception
            (logged). The quote context is always closed.
        """
        quote_ctx = OpenQuoteContext(host=host, port=port)
        try:
            ret, data = quote_ctx.get_stock_basicinfo(market, SecurityType.STOCK)
            if ret != RET_OK:
                self.logger.error(f"获取股票代码失败: {data}")
                return []
            codes = data['code'].astype(str).tolist()
            self.logger.info(f"获取到 {market} 市场 {len(codes)} 个股票代码")
            return codes
        except Exception as e:
            self.logger.error(f"获取股票代码时发生异常: {str(e)}")
            return []
        finally:
            quote_ctx.close()

    def get_stock_codes(self) -> List[str]:
        """
        Return the stock codes from the ``stocks_hk`` table.

        Rows whose code is empty, whose name is empty, or whose name ends in
        'R' are skipped.

        Returns:
            List[str]: The codes; an empty list on any error (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                results = db.execute_query(
                    "SELECT DISTINCT stock_code,stock_name FROM stocks_hk"
                ) or []
                return [
                    row['stock_code']
                    for row in results
                    if row['stock_code']
                    and row.get('stock_name')
                    and not str(row['stock_name']).endswith('R')
                ]
        except Exception as e:
            self.logger.error(f"获取股票代码失败: {str(e)}")
            return []

    @staticmethod
    def read_stock_codes_list(file_path='Reservedcode.txt'):
        """
        Read one stock code per line from a UTF-8 text file.

        Surrounding whitespace is stripped and blank lines are skipped.
        Returns an empty list (and prints a message) if the file is missing
        or cannot be read.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            print(f"文件 {file_path} 不存在")
            return []
        except Exception as e:
            print(f"读取文件失败: {str(e)}")
            return []

    def get_monthly_avg_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read every row of a monthly-average table, ordered by stock_code.

        The column list comes from INFORMATION_SCHEMA (in table order), with
        ``id`` and ``update_time`` excluded.

        Args:
            table_name: Source table name (a trusted identifier).

        Returns:
            List[Dict]: The rows, or None if the table has no columns, has no
            rows, or any error occurs (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                columns = db.execute_query("""
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                    ORDER BY ORDINAL_POSITION
                """, (self.db_config['database'], table_name))

                if not columns:
                    self.logger.error(f"表 {table_name} 不存在或没有列")
                    return None

                field_names = [col['COLUMN_NAME'] for col in columns
                               if col['COLUMN_NAME'] not in ('id', 'update_time')]

                data = db.execute_query(f"""
                    SELECT {', '.join(field_names)}
                    FROM {table_name}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有数据")
                    return None

                return data
        except Exception as e:
            self.logger.error(f"从数据库读取月度均值数据失败: {str(e)}")
            return None

    def get_float_share_data(self, table_name: str) -> Optional[List[Dict]]:
        """
        Read stock_code, stock_name and float_share from ``table_name``.

        Args:
            table_name: Source table name (a trusted identifier).

        Returns:
            List[Dict]: Rows ordered by stock_code, or None if the table is
            empty or any error occurs (logged).
        """
        try:
            with MySQLHelper(**self.db_config) as db:
                data = db.execute_query(f"""
                    SELECT stock_code, stock_name, float_share
                    FROM {table_name}
                    ORDER BY stock_code
                """)

                if not data:
                    self.logger.error(f"表 {table_name} 中没有流通股本数据")
                    return None

                return data
        except Exception as e:
            self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}")
            return None

    def export_to_csv(self, data: List[Dict], output_file: str) -> bool:
        """
        Write rows to a CSV file with a Chinese header row.

        Column order follows the keys of the first row. Headers are mapped
        through HEADER_MAP, falling back to the raw column name. The file is
        UTF-8 with BOM so Excel detects the encoding. The parent directory
        is created if it is missing.

        Args:
            data: Rows to write (all rows must share the first row's keys).
            output_file: Destination CSV path.

        Returns:
            bool: True on success; False if ``data`` is empty or writing
            fails (logged).
        """
        if not data:
            self.logger.warning("没有数据可导出")
            return False

        try:
            field_names = list(data[0].keys())
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, mode='w', newline='', encoding='utf-8-sig') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=field_names)
                writer.writerow({col: self.HEADER_MAP.get(col, col) for col in field_names})
                writer.writerows(data)

            self.logger.info(f"成功导出 {len(data)} 条记录到CSV文件: {output_file}")
            return True
        except Exception as e:
            self.logger.error(f"导出到CSV失败: {str(e)}")
            return False

    def export_data(self, monthly_table: str, csv_file: str = None) -> bool:
        """
        Export a monthly-average table to ``data/<csv_file>``.

        Args:
            monthly_table: Table to read via get_monthly_avg_data.
            csv_file: File name relative to the ``data`` directory. When None,
                nothing is written.

        Returns:
            bool: False if no data could be read or the CSV export failed;
            otherwise True (including when ``csv_file`` is None).
        """
        monthly_data = self.get_monthly_avg_data(monthly_table)
        if not monthly_data:
            self.logger.error("无法获取月度均价数据")
            return False

        if not csv_file:
            return True
        return self.export_to_csv(monthly_data, str(Path('data') / csv_file))