diff --git a/DataAnalysis/DataExporter.py b/DataAnalysis/DataExporter.py index 573597f..355e40c 100644 --- a/DataAnalysis/DataExporter.py +++ b/DataAnalysis/DataExporter.py @@ -143,10 +143,54 @@ class DataExporter: self.logger.error("无法获取月度均价数据") return False + # 读取港股通标记 + hk_inout_data = self.get_hk_inout() + + for item in monthly_data: + stock_name = item.get('stock_code')[3:] + if stock_name in hk_inout_data: + item['in_out'] = 1 + else: + item['in_out'] = 0 + # 导出结果 file_path = 'data/' + csv_file if csv_file else None csv_success = True if csv_file: csv_success = self.export_to_csv(monthly_data, file_path) - return csv_success \ No newline at end of file + return csv_success + + def get_hk_inout(self) -> Optional[List[Dict]]: + """ + 从conditionalselection表读取流通股本数据 + + Args: + table_name: 源数据表名 + + Returns: + List[Dict]: 查询结果数据集,失败返回None + """ + try: + with MySQLHelper(**self.db_config) as db: + # 查询流通股本数据 + data = db.execute_query(f""" + SELECT stock_code + FROM hk_stock_connect + WHERE in_out = '1' + ORDER BY stock_code + """) + + if not data: + self.logger.error(f"获取数据失败") + return None + + return [ + row['stock_code'] + for row in data + if row['stock_code'] + ] + + except Exception as e: + self.logger.error(f"从数据库读取流通股本数据失败: {str(e)}") + return None \ No newline at end of file diff --git a/DataAnalysis/MarketDataCalculator.py b/DataAnalysis/MarketDataCalculator.py index 371d6e8..495458d 100644 --- a/DataAnalysis/MarketDataCalculator.py +++ b/DataAnalysis/MarketDataCalculator.py @@ -16,7 +16,8 @@ from typing import Optional, List, Dict, Union, Tuple import csv from typing import List, Dict, Optional from datetime import datetime -from base import LogHelper, MySQLHelper, Config +from base import LogHelper, MySQLHelper, ConfigInfo + class MarketDataCalculator: """ @@ -41,8 +42,8 @@ class MarketDataCalculator: """ self.db_config = db_config self.logger = LogHelper(logger_name=logger_name).setup() - self.month_ranges = Config.ConfigInfo.MONTH_RANGES - self.head_map = Config.ConfigInfo.HEADER_MAP + self.month_ranges = ConfigInfo.MONTH_RANGES + self.head_map = ConfigInfo.HEADER_MAP def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool: """ @@ -69,6 +70,7 @@ class MarketDataCalculator: ym_2506 DECIMAL(20, 5) COMMENT '2025年06月', ym_2507 DECIMAL(20, 5) COMMENT '2025年07月', ym_2508 DECIMAL(20, 5) COMMENT '2025年08月', + ym_2509 DECIMAL(20, 5) COMMENT '2025年09月', avg_all DECIMAL(20, 5) COMMENT '月间均值', update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', UNIQUE KEY uk_stock_code (stock_code) @@ -82,7 +84,7 @@ class MarketDataCalculator: return False def calculate_and_save_monthly_avg(self, - source_table: str = "stock_quotes", + stock_code: str = "code", target_table: str = "monthly_close_avg") -> bool: """ 计算并保存2024年10月至2025年8月的月均流通市值 @@ -100,79 +102,85 @@ class MarketDataCalculator: return False with MySQLHelper(**self.db_config) as db: + # 获取所有股票代码和名称 - stock_info = db.execute_query( - f"SELECT DISTINCT stock_code, stock_name FROM {source_table}" - ) - - if not stock_info: - self.logger.error("没有获取到股票基本信息") - return False - - # 为每只股票计算各月均值 - for stock in stock_info: - stock_code = stock['stock_code'] - stock_name = stock['stock_name'] - - monthly_data = {'stock_code': stock_code, 'stock_name': stock_name} - - # 计算每个月的均值 - for month_col, (start_date, end_date) in self.month_ranges.items(): - sql = """ - SELECT AVG(close_price * float_share) as avg_close - FROM {} - WHERE stock_code = %s - AND trade_date BETWEEN %s AND %s - AND close_price IS NOT NULL - AND float_share IS NOT NULL - """.format(source_table) - - result = db.execute_query(sql, (stock_code, start_date, end_date)) - # 保存小数点后两位,以亿为单位 - # monthly_data[month_col] = float(result[0]['avg_close']) * 1000 if result and result[0]['avg_close'] else None - monthly_data[month_col] = round(float(result[0]['avg_close']) * 1000 / 100000000, 3) if result and result[0]['avg_close'] else None - - # 提取所有以 'ym_' 开头的键的值 - ym_values = [value for key, value in monthly_data.items() if key.startswith('ym_')] - valid_ym_values = [value for value in ym_values if value is not None] - - # 计算全部月的均值 - if valid_ym_values: - average = sum(valid_ym_values) / len(valid_ym_values) - monthly_data['avg_all'] = average - self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}") - else: - monthly_data['avg_all'] = 0 # 给一个空值,保证数据库不报错 - self.logger.warning(f"股票 {stock_code} 没有有效的月度数据") - - # 插入或更新数据 - upsert_sql = f""" - INSERT INTO {target_table} ( - stock_code, stock_name, - ym_2501, ym_2502, ym_2503, ym_2504, - ym_2505, ym_2506,ym_2507, ym_2508, - avg_all - ) VALUES ( - %(stock_code)s, %(stock_name)s, - %(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s, - %(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s, - %(avg_all)s - ) - ON DUPLICATE KEY UPDATE - stock_name = VALUES(stock_name), - ym_2501 = VALUES(ym_2501), - ym_2502 = VALUES(ym_2502), - ym_2503 = VALUES(ym_2503), - ym_2504 = VALUES(ym_2504), - ym_2505 = VALUES(ym_2505), - ym_2506 = VALUES(ym_2506), - ym_2507 = VALUES(ym_2507), - ym_2508 = VALUES(ym_2508), - avg_all = VALUES(avg_all), - update_time = CURRENT_TIMESTAMP + sql = """ + SELECT stock_code, stock_name + FROM stock_filter + WHERE stock_code = %s """ - db.execute_update(upsert_sql, monthly_data) + + # stock_filter表格可以当成标准表 + stock_info = db.execute_query(sql,(stock_code)) + if len(stock_info) == 0: + return + + stock_code = stock_info[0]['stock_code'] + stock_name = stock_info[0]['stock_name'] + + monthly_data = {'stock_code': stock_code, 'stock_name': stock_name} + + # 计算每个月的均值 + source_table = 'hk_' + stock_code[3:] + for month_col, (start_date, end_date) in self.month_ranges.items(): + sql = """ + SELECT AVG(close_price * float_share) as avg_close + FROM {} + WHERE stock_code = %s + AND trade_date BETWEEN %s AND %s + AND close_price IS NOT NULL + AND float_share IS NOT NULL + """.format(source_table) + + result = db.execute_query(sql, (stock_code, start_date, end_date)) + # 保存小数点后两位,以亿为单位 + # monthly_data[month_col] = float(result[0]['avg_close']) * 1000 if result and result[0]['avg_close'] else None + monthly_data[month_col] = round(float(result[0]['avg_close']) * 1000 / 100000000, 3) if result and result[0]['avg_close'] else None + + # 提取所有以 'ym_' 开头的键的值 + ym_values = [value for key, value in monthly_data.items() if key.startswith('ym_')] + valid_ym_values = [value for value in ym_values if value is not None] + + # 计算全部月的均值 + if valid_ym_values: + average = sum(valid_ym_values) / len(valid_ym_values) + monthly_data['avg_all'] = average + self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}") + else: + monthly_data['avg_all'] = 0 # 给一个空值,保证数据库不报错 + self.logger.warning(f"股票 {stock_code} 没有有效的月度数据") + + # 插入或更新数据 + upsert_sql = f""" + INSERT INTO {target_table} ( + stock_code, stock_name, + ym_2501, ym_2502, ym_2503, ym_2504, + ym_2505, ym_2506,ym_2507, ym_2508, + ym_2509, + avg_all + ) VALUES ( + %(stock_code)s, %(stock_name)s, + %(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s, + %(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s, + %(ym_2509)s, + %(avg_all)s + ) + ON DUPLICATE KEY UPDATE + stock_name = VALUES(stock_name), + ym_2501 = VALUES(ym_2501), + ym_2502 = VALUES(ym_2502), + ym_2503 = VALUES(ym_2503), + ym_2504 = VALUES(ym_2504), + ym_2505 = VALUES(ym_2505), + ym_2506 = VALUES(ym_2506), + ym_2507 = VALUES(ym_2507), + ym_2508 = VALUES(ym_2508), + ym_2509 = VALUES(ym_2509), + avg_all = VALUES(avg_all), + update_time = CURRENT_TIMESTAMP + """ + db.execute_update(upsert_sql, monthly_data) # self.logger.info("月度均值计算和保存完成") return True except Exception as e: @@ -435,4 +443,6 @@ class MarketDataCalculator: if csv_file: csv_success = self.export_to_csv(monthly_data, file_path) - return csv_success \ No newline at end of file + return csv_success + + \ No newline at end of file diff --git a/UpdateFutuData/KLineUpdater.py b/UpdateFutuData/KLineUpdater.py index 94ff100..30932c6 100644 --- a/UpdateFutuData/KLineUpdater.py +++ b/UpdateFutuData/KLineUpdater.py @@ -193,7 +193,7 @@ class KLineUpdater: # 预处理数据 processed_data = self.preprocess_quote_data(quote_data, float_share) if not processed_data: - self.logger.error("没有有效数据需要保存") + self.logger.error(f"没有有效数据需要保存,表:{table_name}") return False # 动态生成SQL插入语句 @@ -252,7 +252,7 @@ class KLineUpdater: self.logger.info(f"创建了新表: {table_name}") affected_rows = db.execute_many(insert_sql, processed_data) - self.logger.info(f"成功插入/更新 {affected_rows} 条行情记录到表 {table_name}") + # self.logger.info(f"成功插入/更新 {affected_rows} 条行情记录到表 {table_name}") return True except Exception as e: self.logger.error(f"保存行情数据到表 {table_name} 失败: {str(e)}") diff --git a/base/Config.py b/base/Config.py index 4b1be2b..75f59a7 100644 --- a/base/Config.py +++ b/base/Config.py @@ -23,7 +23,9 @@ class ConfigInfo: 'ym_2506': '2025年06月', 'ym_2507': '2025年07月', 'ym_2508': '2025年08月', - 'avg_all': '月度均值' + 'ym_2509': '2025年09月', + 'avg_all': '月度均值', + 'in_out':'是否在港股通' } # 月份范围配置 @@ -35,6 +37,7 @@ class ConfigInfo: 'ym_2505': ('2025-05-01', '2025-05-31'), 'ym_2506': ('2025-06-01', '2025-06-30'), 'ym_2507': ('2025-07-01', '2025-07-31'), - 'ym_2508': ('2025-08-01', '2025-08-31') + 'ym_2508': ('2025-08-01', '2025-08-31'), + 'ym_2509': ('2025-09-01', '2025-09-30') } diff --git a/base/MySQLHelper.py b/base/MySQLHelper.py index d8f731d..95d9fd1 100644 --- a/base/MySQLHelper.py +++ b/base/MySQLHelper.py @@ -9,7 +9,7 @@ import pymysql from pymysql import Error from typing import List, Dict, Union, Optional, Tuple from contextlib import contextmanager -from base.LogHelper import LogHelper +from .LogHelper import LogHelper # 基本用法(自动创建日期日志+控制台输出) logger = LogHelper(logger_name = 'database').setup() @@ -84,7 +84,9 @@ class MySQLHelper: """ try: self.cursor.execute(sql, params) - return self.cursor.fetchall() + result = self.cursor.fetchall() + + return result except Error as e: logger.error(f"查询执行失败: {e}") return [] @@ -233,4 +235,4 @@ class MySQLHelper: return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close() \ No newline at end of file + self.close() diff --git a/base/StockDataImporter.py b/base/StockDataImporter.py new file mode 100644 index 0000000..71808c4 --- /dev/null +++ b/base/StockDataImporter.py @@ -0,0 +1,499 @@ +""" + 存储上海证券交易股票列表数据 + + 不确定其数据爬取规则,防止 IP 被封 + 暂时使用该方案,获取股票列表数据 + —— 下载excel,收到导入到数据库 +""" + +import pandas as pd +import os +import sys +import csv +import chardet # 用于检测文件编码 +from pathlib import Path +from datetime import datetime + +from .MySQLHelper import MySQLHelper +from .LogHelper import LogHelper + +logger = LogHelper(logger_name = 'execelImport').setup() + +class StockDataImporter: + """股票数据导入工具(支持CSV)""" + + def __init__(self, db_config: dict, column_mapping: dict, data_dir: Path): + self.db_config = db_config + self.column_mapping = column_mapping + self.data_dir = data_dir + self.df = None + self.csv_file = None + self.encoding = 'utf-8' # 默认编码 + self.delimiter = ',' # 默认分隔符 + self.upload_filename = None # 上传文件名 + # 更新 检讨标志 + + def setUploadfile(self, filename: str): + """设置需要上传的文件名""" + self.upload_filename = filename + logger.info(f"设置上传文件名为: {filename}") + + def find_csv_file(self) -> Path: + """在data文件夹中查找CSV文件""" + # 使用设置的upload_filename或默认文件名 + filename = self.upload_filename if self.upload_filename else "GPLIST.csv" + + # 查找所有匹配的文件 + csv_files = list(self.data_dir.glob(filename)) + + if not csv_files: + logger.error(f"在 {self.data_dir} 中没有找到文件: {filename}") + return None + + # 如果有多个文件,选择最新的 + if len(csv_files) > 1: + csv_files.sort(key=os.path.getmtime, reverse=True) + logger.info(f"找到多个文件,选择最新的: {csv_files[0].name}") + + return csv_files[0] + + def validate_file(self, file_path: Path) -> bool: + """验证CSV文件是否有效""" + try: + if not file_path.exists(): + logger.error(f"CSV文件不存在: {file_path}") + return False + + file_size = file_path.stat().st_size + if file_size == 0: + logger.error(f"CSV文件为空: {file_path}") + return False + + return True + except Exception as e: + logger.error(f"文件验证失败: {e}") + return False + + def detect_file_encoding(self, file_path: Path) -> str: + """检测文件编码""" + try: + # 读取文件开头部分进行编码检测 + with open(file_path, 'rb') as f: + raw_data = f.read(10000) # 读取前10KB + + # 使用chardet检测编码 + result = chardet.detect(raw_data) + encoding = result['encoding'] + confidence = result['confidence'] + + # 常见编码替代 + encoding_map = { + 'GB2312': 'GBK', + 'gb2312': 'GBK', + 'ISO-8859-1': 'latin1', + 'ascii': 'utf-8' + } + + # 应用映射 + encoding = encoding_map.get(encoding, encoding) + + logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})") + return encoding or 'utf-8' + except Exception as e: + logger.error(f"编码检测失败: {e}, 使用默认UTF-8") + return 'utf-8' + + def detect_csv_delimiter(self, file_path: Path) -> str: + """自动检测CSV分隔符""" + try: + # 使用检测到的编码打开文件 + with open(file_path, 'r', encoding=self.encoding) as f: + # 读取前5行 + lines = [f.readline() for _ in range(5) if f.readline()] + + # 尝试常见分隔符 + delimiters = [',', '\t', ';', '|'] + delimiter_counts = {} + + for delim in delimiters: + count = 0 + for line in lines: + count += line.count(delim) + delimiter_counts[delim] = count + + # 选择出现次数最多的分隔符 + best_delim = max(delimiter_counts, key=delimiter_counts.get) + + # 如果没有任何分隔符,则使用逗号 + if delimiter_counts[best_delim] == 0: + logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符") + return ',' + + logger.info(f"检测到分隔符: {repr(best_delim)}") + return best_delim + except Exception as e: + logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符") + return ',' + + def read_csv_data(self, file_path: Path) -> bool: + """从CSV文件读取数据""" + try: + # 1. 检测文件编码 + self.encoding = self.detect_file_encoding(file_path) + + # 2. 检测分隔符 + self.delimiter = self.detect_csv_delimiter(file_path) + + # 3. 读取CSV文件 + logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件") + + self.df = pd.read_csv( + file_path, + delimiter=self.delimiter, + dtype=str, + encoding=self.encoding, + on_bad_lines='warn', + quoting=csv.QUOTE_MINIMAL, + engine='python' # 更健壮的引擎 + ) + + # 检查是否读取到数据 + if self.df.empty: + logger.error("CSV文件没有包含有效数据") + return False + + # 重命名列 + self.df = self.df.rename(columns=self.column_mapping) + + # 移除可能存在的空行 + self.df = self.df.dropna(how='all') + + logger.info(f"成功读取CSV数据,共 {len(self.df)} 条记录") + return True + except UnicodeDecodeError: + # 尝试其他编码 + encodings_to_try = ['GBK', 'latin1', 'ISO-8859-1', 'utf-16'] + for enc in encodings_to_try: + try: + logger.warning(f"尝试使用 {enc} 编码读取文件") + self.df = pd.read_csv( + file_path, + delimiter=self.delimiter, + dtype=str, + encoding=enc + ) + self.encoding = enc + logger.info(f"成功使用 {enc} 编码读取文件") + return True + except: + continue + + logger.error("所有编码尝试均失败") + return False + except PermissionError: + logger.error(f"文件被占用,请关闭后重试: {file_path}") + return False + except Exception as e: + logger.error(f"读取CSV文件失败: {e}") + return False + + def clean_stock_data(self) -> bool: + """清洗股票数据(基本清洗,主要处理股票代码格式)""" + try: + # 验证股票代码格式(如果存在stock_code列) + if 'stock_code' in self.df.columns: + invalid_codes = self.df[~self.df['stock_code'].astype(str).str.match(r'^\d{6}$')] + if not invalid_codes.empty: + logger.warning(f"发现 {len(invalid_codes)} 条无效的股票代码") + logger.debug(f"无效代码示例: {invalid_codes['stock_code'].head().tolist()}") + + logger.info("数据清洗完成") + return True + except Exception as e: + logger.error(f"数据清洗失败: {e}") + return False + + def create_stocks_table(self, db: MySQLHelper) -> bool: + """创建股票信息表(包含股票代码、中文名称、英文名称、进出标志和时间戳)""" + # 定义列类型映射 + column_type_mapping = { + 'stock_code': 'VARCHAR(6) PRIMARY KEY', + 'stock_name_chn': 'VARCHAR(50) NULL', + 'stock_name_en': 'VARCHAR(150)', + } + + # 构建列定义SQL + column_definitions = [] + for column_name in self.column_mapping.values(): + if column_name in column_type_mapping: + column_definitions.append(f"{column_name} {column_type_mapping[column_name]}") + else: + # 对于未知列,使用VARCHAR(255) + column_definitions.append(f"{column_name} VARCHAR(255)") + + # 添加进出标志列 + column_definitions.append("in_out TINYINT(1) DEFAULT 0 COMMENT '进出标志'") + + # 添加时间戳列 + column_definitions.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'") + column_definitions.append("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间'") + + # 构建完整的CREATE TABLE SQL + columns_sql = ",\n ".join(column_definitions) + create_table_sql = f""" + CREATE TABLE IF NOT EXISTS hk_stock_connect ( + {columns_sql} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票信息表'; + """ + + try: + db.execute_update(create_table_sql) + logger.info("股票信息表创建成功") + return True + except Exception as e: + logger.error(f"创建表失败: {e}") + return False + + def insert_data_to_db(self, db: MySQLHelper) -> bool: + """将数据插入数据库(处理映射的列和进出标志)""" + if self.df is None or self.df.empty: + logger.error("没有有效数据可插入") + return False + + # 获取所有映射的列名 + mapped_columns = list(self.column_mapping.values()) + + # 构建INSERT SQL语句(包含in_out列,设置为1) + columns_sql = ", ".join(mapped_columns + ['in_out']) + placeholders = ", ".join(["%s"] * len(mapped_columns) + ["1"]) # in_out固定为1 + + # 构建ON DUPLICATE KEY UPDATE部分(主键不更新,但更新in_out字段) + update_clauses = [] + for column in mapped_columns: + if column != 'stock_code': # 主键不更新 + update_clauses.append(f"{column} = VALUES({column})") + # 添加in_out字段更新,确保在重复时也设置为1 + update_clauses.append("in_out = VALUES(in_out)") + update_sql = ", ".join(update_clauses) + + insert_sql = f""" + INSERT INTO hk_stock_connect ( + {columns_sql} + ) VALUES ( + {placeholders} + ) + ON DUPLICATE KEY UPDATE + {update_sql} + """ + + # 准备参数列表(只包含映射的列,in_out由SQL固定为1) + params_list = [] + for _, row in self.df.iterrows(): + params = [] + for column in mapped_columns: + # 处理可能的NaN值 + value = row[column] if column in row and pd.notna(row[column]) else None + params.append(value) + + params_list.append(tuple(params)) + + # 批量执行插入 + try: + total_rows = len(params_list) + if total_rows == 0: + logger.error("没有有效数据可插入") + return False + + batch_size = 1000 # 每批插入1000条记录 + + logger.info(f"开始插入数据,共 {total_rows} 条记录") + + # 分批插入,避免大事务问题 + for i in range(0, total_rows, batch_size): + batch_params = params_list[i:i+batch_size] + affected_rows = db.execute_many(insert_sql, batch_params) + logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录") + + logger.info(f"成功插入/更新 {total_rows} 条记录") + return True + except Exception as e: + logger.error(f"插入数据失败: {e}") + # 记录前5个参数以帮助调试 + if params_list: + logger.debug(f"前5个参数示例: {params_list[:5]}") + return False + + def setInOut(self, db: MySQLHelper) -> bool: + """设置进出标志为0,且不更新updated_at字段""" + try: + # # 首先检查表是否存在 + # table_check_sql = """ + # SELECT COUNT(*) AS table_exists + # FROM information_schema.tables + # WHERE table_schema = DATABASE() + # AND table_name = 'hk_stock_connect' + # """ + # result = db.execute_query(table_check_sql) + + # if not result or result[0]['table_exists'] == 0: + # logger.warning("表 'hk_stock_connect' 不存在,跳过设置进出标志操作") + # return True + + # 表存在,执行更新操作 + update_sql = "UPDATE hk_stock_connect SET in_out = 1, updated_at = updated_at" + affected_rows = db.execute_update(update_sql) + logger.info(f"成功设置 {affected_rows} 条记录的进出标志为0") + return True + except Exception as e: + logger.error(f"设置进出标志失败: {e}") + return False + + def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool: + """验证数据库中的数据""" + try: + # 检查记录总数 + count_sql = "SELECT COUNT(*) AS total FROM hk_stock_connect" + result = db.execute_query(count_sql) + db_count = result[0]['total'] if result else 0 + logger.info(f"数据库中共有 {db_count} 条记录") + + # 获取映射的列名用于显示 + mapped_columns = list(self.column_mapping.values()) + + # 构建查询列(使用所有映射的列和in_out字段) + select_columns = mapped_columns + ['in_out'] if mapped_columns else ["*"] + columns_sql = ", ".join(select_columns) + + # 随机抽样检查 + sample_sql = f""" + SELECT {columns_sql} + FROM hk_stock_connect + ORDER BY RAND() + LIMIT {sample_size} + """ + samples = db.execute_query(sample_sql) + + logger.info("\n随机抽样记录:") + for idx, sample in enumerate(samples, 1): + # 动态构建日志消息,显示所有映射的列和in_out字段 + sample_info = [] + for column in select_columns: + if column in sample and sample[column] is not None: + sample_info.append(f"{column}: {sample[column]}") + + logger.info(f"{idx}. {' | '.join(sample_info) if sample_info else 'No data to display'}") + + return True + except Exception as e: + logger.error(f"数据验证失败: {e}") + return False + + def run_import(self) -> bool: + """执行完整的导入流程""" + logger.info(f"开始导入股票数据,数据目录: {self.data_dir}") + start_time = datetime.now() + + # 1. 查找CSV文件 + csv_file = self.find_csv_file() + if not csv_file: + return False + + # 2. 验证文件 + if not self.validate_file(csv_file): + return False + + # 3. 读取CSV数据 + if not self.read_csv_data(csv_file): + return False + + # 4. 清洗数据 + if not self.clean_stock_data(): + return False + + # 显示前5条数据(动态显示可用列) + logger.info("\n前5条股票数据:") + for i, row in self.df.head().iterrows(): + # 动态构建显示信息 + display_info = [] + if 'stock_code' in row: + display_info.append(f"代码: {row['stock_code']}") + if 'stock_name_chn' in row: + display_info.append(f"名称: {row['stock_name_chn']}") + if 'stock_name_en' in row: + display_info.append(f"英文: {row['stock_name_en']}") + + logger.info(f"{i+1}. {' | '.join(display_info)}") + + # 5. 连接数据库并导入 + try: + with MySQLHelper(**self.db_config) as db: + # 5.1 创建表 + if not self.create_stocks_table(db): + return False + + # 更新检讨结果 + self.setInOut(db) + + # 5.2 插入数据 + if not self.insert_data_to_db(db): + return False + + # 5.3 验证数据 + if not self.verify_data_in_db(db): + return False + except Exception as e: + logger.error(f"数据库操作异常: {e}") + return False + + # 计算执行时间 + duration = datetime.now() - start_time + logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}秒") + return True + +if __name__ == "__main__": + + # 数据库配置 + db_config = { + 'host': 'localhost', + 'user': 'root', + 'password': 'bzskmysql', + 'database': 'hk_kline_1d' + } + + # 列映射配置 + COLUMN_MAPPING = { + '证券代码': 'stock_code', + '中文简称': 'stock_name_chn', + '英文简称': 'stock_name_en', + } + + # 获取当前脚本所在目录 + current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() + + # 设置数据目录 + DATA_DIR = current_dir / "data" + + # 确保data目录存在 + DATA_DIR.mkdir(exist_ok=True, parents=True) + + # 安装依赖 (如果chardet未安装) + try: + import chardet + except ImportError: + logger.info("安装chardet库以支持编码检测...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "chardet"]) + import chardet + + # 创建导入器并执行导入(使用新的参数顺序) + importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR) + + + + # 2.2w设置上传文件名(可以注释掉使用默认文件名) + importer.setUploadfile("港股通标的证券名单.csv") + + if importer.run_import(): + logger.info("股票数据导入成功!") + else: + logger.error("股票数据导入失败,请检查日志了解详情") diff --git a/base/__init__.py b/base/__init__.py index a75bcb9..b1393bb 100644 --- a/base/__init__.py +++ b/base/__init__.py @@ -1,3 +1,4 @@ from .MySQLHelper import MySQLHelper from .LogHelper import LogHelper -from .Config import ConfigInfo \ No newline at end of file +from .Config import ConfigInfo +from .StockDataImporter import StockDataImporter \ No newline at end of file diff --git a/config/HK_futu.txt b/config/HK_futu.txt index c4b5a77..41b1caa 100644 --- a/config/HK_futu.txt +++ b/config/HK_futu.txt @@ -895,4 +895,8 @@ HK.08218 HK.06960 HK.02936 HK.03858 -HK.08132 \ No newline at end of file +HK.08132 +HK.02938 +HK.02580 +HK.02941 +HK.02935 \ No newline at end of file diff --git a/config/Removecode.txt b/config/Removecode.txt new file mode 100644 index 0000000..35f5a04 --- /dev/null +++ b/config/Removecode.txt @@ -0,0 +1,2 @@ +HK.04335 +HK.02292 \ No newline at end of file diff --git a/config/kevin_futu.txt b/config/kevin_futu.txt index 59d3ed3..67a5ad2 100644 --- a/config/kevin_futu.txt +++ b/config/kevin_futu.txt @@ -686,7 +686,6 @@ HK.02179 HK.00442 HK.01959 HK.01985 -HK.02992 HK.00314 HK.01459 HK.01082 diff --git a/config/missing_codes.txt b/config/missing_codes.txt index e69de29..b8e80fd 100644 --- a/config/missing_codes.txt +++ b/config/missing_codes.txt @@ -0,0 +1 @@ +HK.02938 \ No newline at end of file diff --git a/config/备注.txt b/config/备注.txt index 85e831e..00008cf 100644 --- a/config/备注.txt +++ b/config/备注.txt @@ -1,5 +1,5 @@ kevin_futu: - 已使用 1000 行情 + 已使用 1000 行情 -> 移除 HK.02992 hang_futu: 已使用 703 行情 HK牛仔: diff --git a/main_gui.py b/main_gui.py index c7e8f45..e096527 100644 --- a/main_gui.py +++ b/main_gui.py @@ -102,6 +102,9 @@ class MainWindow(QMainWindow): # 创建功能按钮组 self.create_button_group(main_layout) + # 创建数据导入按钮组 + self.create_import_button_group(main_layout) + # 创建进度条 self.progress_bar = QProgressBar() # self.progress_bar.setVisible(False) @@ -307,6 +310,20 @@ class MainWindow(QMainWindow): button_group.setLayout(button_layout) layout.addWidget(button_group) + + def create_import_button_group(self, layout): + """创建数据导入按钮组""" + import_group = QGroupBox("数据导入") + import_layout = QHBoxLayout() + + # 导入数据按钮 + self.btn_import = QPushButton('导入股票数据') + self.btn_import.clicked.connect(self.on_import_clicked) + self.btn_import.setToolTip('从CSV文件导入股票数据到数据库') + import_layout.addWidget(self.btn_import) + + import_group.setLayout(import_layout) + layout.addWidget(import_group) def create_log_area(self, layout): """创建日志显示区域""" @@ -341,6 +358,8 @@ class MainWindow(QMainWindow): self.btn_export.setEnabled(enabled) self.btn_calculate.setEnabled(enabled) self.btn_check.setEnabled(enabled) + self.btn_import.setEnabled(enabled) + self.btn_float_share.setEnabled(enabled) def on_update_clicked(self): """更新数据按钮点击事件""" @@ -401,6 +420,17 @@ class MainWindow(QMainWindow): worker.finished_signal.connect(self.on_task_finished) worker.start() self.worker_threads.append(worker) + + def on_import_clicked(self): + """导入数据按钮点击事件""" + self.log_message("开始导入股票数据...") + self.set_buttons_enabled(False) + + worker = WorkerThread(self.import_stock_data) + worker.log_signal.connect(self.log_message) + worker.finished_signal.connect(self.on_task_finished) + worker.start() + self.worker_threads.append(worker) def on_task_finished(self, success, message): """任务完成回调""" @@ -486,8 +516,9 @@ class MainWindow(QMainWindow): # 移除人民币交易的股票:股票名称最后一个字符为R,误删除的从配置文件读回来 reserved_codes = calculator.read_stock_codes_list(Path.cwd() / "config" / "Reservedcode.txt") + remove_codes = calculator.read_stock_codes_list(Path.cwd() / "config" / "Removecode.txt") market_data_ll = calculator.get_stock_codes() # 使用按照价格和流通股数量筛选的那个表格 - market_data = market_data_ll + reserved_codes + market_data = market_data_ll + reserved_codes - remove_codes # 根据统计时间进行命名 target_table_name = 'hk_monthly_avg_' + datetime.now().strftime("%Y%m%d") @@ -496,11 +527,10 @@ class MainWindow(QMainWindow): # 使用tqdm创建进度条 for code in tqdm(market_data, desc="处理股票数据", unit="支"): - tablename = 'hk_' + code[3:] # 计算并保存月度均值 calculator.calculate_and_save_monthly_avg( - source_table=tablename, - target_table=target_table_name + stock_code =code, + target_table = target_table_name ) # self.log_message("月度平均计算完成") @@ -529,6 +559,54 @@ class MainWindow(QMainWindow): except Exception as e: self.log_message(f"数据检查失败: {str(e)}") raise + + def import_stock_data(self): + """导入股票数据任务""" + try: + # 导入必要的模块 + from base.StockDataImporter import StockDataImporter + from base.MySQLHelper import MySQLHelper + from pathlib import Path + + # 数据库配置 + db_config = { + 'host': 'localhost', + 'user': 'root', + 'password': 'bzskmysql', + 'database': 'hk_kline_1d' + } + + # 列映射配置 + COLUMN_MAPPING = { + '证券代码': 'stock_code', + '中文简称': 'stock_name_chn', + '英文简称': 'stock_name_en', + } + + # 设置数据目录 + current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() + DATA_DIR = current_dir / "data" + + # 确保data目录存在 + DATA_DIR.mkdir(exist_ok=True, parents=True) + + # 创建导入器 + importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR) + + # 设置上传文件名(使用默认文件名 "港股通标的证券名单.csv") + importer.setUploadfile("港股通标的证券名单.csv") + + # 执行导入 + if importer.run_import(): + self.log_message("股票数据导入成功!") + return True + else: + self.log_message("股票数据导入失败,请检查日志了解详情") + return False + + except Exception as e: + self.log_message(f"导入股票数据失败: {str(e)}") + raise def closeEvent(self, event): """窗口关闭事件"""