""" 存储上海证券交易股票列表数据 不确定其数据爬取规则,防止 IP 被封 暂时使用该方案,获取股票列表数据 —— 下载excel,收到导入到数据库 """ import pandas as pd import os import sys import csv import chardet # 用于检测文件编码 from pathlib import Path from datetime import datetime from .MySQLHelper import MySQLHelper from .LogHelper import LogHelper logger = LogHelper(logger_name = 'execelImport').setup() class StockDataImporter: """股票数据导入工具(支持CSV)""" def __init__(self, db_config: dict, column_mapping: dict, data_dir: Path): self.db_config = db_config self.column_mapping = column_mapping self.data_dir = data_dir self.df = None self.csv_file = None self.encoding = 'utf-8' # 默认编码 self.delimiter = ',' # 默认分隔符 self.upload_filename = None # 上传文件名 # 更新 检讨标志 def setUploadfile(self, filename: str): """设置需要上传的文件名""" self.upload_filename = filename logger.info(f"设置上传文件名为: {filename}") def find_csv_file(self) -> Path: """在data文件夹中查找CSV文件""" # 使用设置的upload_filename或默认文件名 filename = self.upload_filename if self.upload_filename else "GPLIST.csv" # 查找所有匹配的文件 csv_files = list(self.data_dir.glob(filename)) if not csv_files: logger.error(f"在 {self.data_dir} 中没有找到文件: {filename}") return None # 如果有多个文件,选择最新的 if len(csv_files) > 1: csv_files.sort(key=os.path.getmtime, reverse=True) logger.info(f"找到多个文件,选择最新的: {csv_files[0].name}") return csv_files[0] def validate_file(self, file_path: Path) -> bool: """验证CSV文件是否有效""" try: if not file_path.exists(): logger.error(f"CSV文件不存在: {file_path}") return False file_size = file_path.stat().st_size if file_size == 0: logger.error(f"CSV文件为空: {file_path}") return False return True except Exception as e: logger.error(f"文件验证失败: {e}") return False def detect_file_encoding(self, file_path: Path) -> str: """检测文件编码""" try: # 读取文件开头部分进行编码检测 with open(file_path, 'rb') as f: raw_data = f.read(10000) # 读取前10KB # 使用chardet检测编码 result = chardet.detect(raw_data) encoding = result['encoding'] confidence = result['confidence'] # 常见编码替代 encoding_map = { 'GB2312': 'GBK', 'gb2312': 'GBK', 'ISO-8859-1': 'latin1', 'ascii': 'utf-8' } # 应用映射 encoding = encoding_map.get(encoding, encoding) logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})") return encoding or 'utf-8' except Exception as e: logger.error(f"编码检测失败: {e}, 使用默认UTF-8") return 'utf-8' def detect_csv_delimiter(self, file_path: Path) -> str: """自动检测CSV分隔符""" try: # 使用检测到的编码打开文件 with open(file_path, 'r', encoding=self.encoding) as f: # 读取前5行 lines = [f.readline() for _ in range(5) if f.readline()] # 尝试常见分隔符 delimiters = [',', '\t', ';', '|'] delimiter_counts = {} for delim in delimiters: count = 0 for line in lines: count += line.count(delim) delimiter_counts[delim] = count # 选择出现次数最多的分隔符 best_delim = max(delimiter_counts, key=delimiter_counts.get) # 如果没有任何分隔符,则使用逗号 if delimiter_counts[best_delim] == 0: logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符") return ',' logger.info(f"检测到分隔符: {repr(best_delim)}") return best_delim except Exception as e: logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符") return ',' def read_csv_data(self, file_path: Path) -> bool: """从CSV文件读取数据""" try: # 1. 检测文件编码 self.encoding = self.detect_file_encoding(file_path) # 2. 检测分隔符 self.delimiter = self.detect_csv_delimiter(file_path) # 3. 
            # 3. Read the CSV file.
            logger.info(f"Reading file with encoding '{self.encoding}' and delimiter '{self.delimiter}'")
            self.df = pd.read_csv(
                file_path,
                delimiter=self.delimiter,
                dtype=str,
                encoding=self.encoding,
                on_bad_lines='warn',
                quoting=csv.QUOTE_MINIMAL,
                engine='python'  # the more robust parser
            )

            # Make sure we actually got data.
            if self.df.empty:
                logger.error("CSV file contains no usable data")
                return False

            # Rename the columns and drop fully empty rows.
            self.df = self.df.rename(columns=self.column_mapping)
            self.df = self.df.dropna(how='all')

            logger.info(f"Successfully read CSV data, {len(self.df)} records")
            return True

        except UnicodeDecodeError:
            # The detected encoding failed; try the other common encodings.
            for enc in ['GBK', 'latin1', 'ISO-8859-1', 'utf-16']:
                try:
                    logger.warning(f"Retrying with encoding {enc}")
                    self.df = pd.read_csv(
                        file_path,
                        delimiter=self.delimiter,
                        dtype=str,
                        encoding=enc
                    )
                    self.encoding = enc
                    # Apply the same post-processing as the primary path.
                    self.df = self.df.rename(columns=self.column_mapping).dropna(how='all')
                    logger.info(f"Successfully read file with encoding {enc}")
                    return True
                except Exception:
                    continue
            logger.error("All encoding attempts failed")
            return False
        except PermissionError:
            logger.error(f"File is locked by another process, close it and retry: {file_path}")
            return False
        except Exception as e:
            logger.error(f"Failed to read CSV file: {e}")
            return False

    def clean_stock_data(self) -> bool:
        """Clean the stock data (basic cleaning, mainly the stock-code format)."""
        try:
            # Validate the stock-code format (if a stock_code column exists).
            if 'stock_code' in self.df.columns:
                invalid_codes = self.df[~self.df['stock_code'].astype(str).str.match(r'^\d{6}$')]
                if not invalid_codes.empty:
                    logger.warning(f"Found {len(invalid_codes)} invalid stock codes")
                    logger.debug(f"Sample invalid codes: {invalid_codes['stock_code'].head().tolist()}")

            logger.info("Data cleaning finished")
            return True
        except Exception as e:
            logger.error(f"Data cleaning failed: {e}")
            return False

    def create_stocks_table(self, db: MySQLHelper) -> bool:
        """Create the stock table (code, Chinese name, English name, in/out flag, timestamps)."""
        # Column-type mapping for the known columns.
        column_type_mapping = {
            'stock_code': 'VARCHAR(6) PRIMARY KEY',
            'stock_name_chn': 'VARCHAR(50) NULL',
            'stock_name_en': 'VARCHAR(150)',
        }

        # Build the column-definition SQL.
        column_definitions = []
        for column_name in self.column_mapping.values():
            if column_name in column_type_mapping:
                column_definitions.append(f"{column_name} {column_type_mapping[column_name]}")
            else:
                # Unknown columns default to VARCHAR(255).
                column_definitions.append(f"{column_name} VARCHAR(255)")

        # Add the in/out flag column.
        column_definitions.append("in_out TINYINT(1) DEFAULT 0 COMMENT 'in/out flag'")

        # Add the timestamp columns.
        column_definitions.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created at'")
        column_definitions.append("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'updated at'")

        # Assemble the full CREATE TABLE statement.
        columns_sql = ",\n            ".join(column_definitions)
        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS hk_stock_connect (
            {columns_sql}
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='stock information table';
        """

        try:
            db.execute_update(create_table_sql)
            logger.info("Stock table created successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to create table: {e}")
            return False

    def insert_data_to_db(self, db: MySQLHelper) -> bool:
        """Insert the data into the database (mapped columns plus the in/out flag)."""
        if self.df is None or self.df.empty:
            logger.error("No valid data to insert")
            return False

        # All mapped column names.
        mapped_columns = list(self.column_mapping.values())

        # Build the INSERT statement (including in_out, fixed to 1).
        columns_sql = ", ".join(mapped_columns + ['in_out'])
        placeholders = ", ".join(["%s"] * len(mapped_columns) + ["1"])  # in_out is always 1

        # Build the ON DUPLICATE KEY UPDATE clause (the primary key is never
        # updated, but the in_out flag is refreshed).
        update_clauses = []
        for column in mapped_columns:
            if column != 'stock_code':  # never update the primary key
                update_clauses.append(f"{column} = VALUES({column})")
        # VALUES(in_out) resolves to the literal 1 above, so duplicate rows
        # are flagged as "in" again as well.
        update_clauses.append("in_out = VALUES(in_out)")
        update_sql = ", ".join(update_clauses)

        insert_sql = f"""
        INSERT INTO hk_stock_connect (
            {columns_sql}
        ) VALUES (
            {placeholders}
        ) ON DUPLICATE KEY UPDATE
            {update_sql}
        """

        # Build the parameter list (mapped columns only; in_out is hard-coded in the SQL).
        params_list = []
        for _, row in 
self.df.iterrows():
            params = []
            for column in mapped_columns:
                # Convert NaN to None so it becomes SQL NULL.
                value = row[column] if column in row and pd.notna(row[column]) else None
                params.append(value)
            params_list.append(tuple(params))

        # Run the inserts in batches.
        try:
            total_rows = len(params_list)
            if total_rows == 0:
                logger.error("No valid data to insert")
                return False

            batch_size = 1000  # rows per batch
            logger.info(f"Starting insert, {total_rows} records in total")

            # Insert in batches to avoid one oversized transaction.
            for i in range(0, total_rows, batch_size):
                batch_params = params_list[i:i + batch_size]
                db.execute_many(insert_sql, batch_params)
                logger.info(f"Processed {min(i + batch_size, total_rows)}/{total_rows} records")

            logger.info(f"Successfully inserted/updated {total_rows} records")
            return True
        except Exception as e:
            logger.error(f"Insert failed: {e}")
            # Log the first few parameter tuples to help debugging.
            if params_list:
                logger.debug(f"First 5 parameter tuples: {params_list[:5]}")
            return False

    def setInOut(self, db: MySQLHelper) -> bool:
        """Reset the in/out flag to 0 without touching updated_at."""
        try:
            # (A table-existence check against information_schema used to guard
            # this update; it is redundant now that run_import creates the
            # table first.)

            # Assigning updated_at to itself stops the ON UPDATE trigger from
            # bumping the timestamp.
            update_sql = "UPDATE hk_stock_connect SET in_out = 0, updated_at = updated_at"
            affected_rows = db.execute_update(update_sql)
            logger.info(f"Reset the in/out flag to 0 on {affected_rows} records")
            return True
        except Exception as e:
            logger.error(f"Failed to reset the in/out flag: {e}")
            return False

    def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool:
        """Verify the data in the database."""
        try:
            # Total record count.
            count_sql = "SELECT COUNT(*) AS total FROM hk_stock_connect"
            result = db.execute_query(count_sql)
            db_count = result[0]['total'] if result else 0
            logger.info(f"The database holds {db_count} records")

            # Mapped column names for display.
            mapped_columns = list(self.column_mapping.values())

            # Columns to select (all mapped columns plus the in_out flag).
            select_columns = (mapped_columns + ['in_out']) if mapped_columns else ["*"]
            columns_sql = ", ".join(select_columns)

            # Random sample check.
            sample_sql = f"""
            SELECT {columns_sql}
            FROM hk_stock_connect
            ORDER BY RAND()
            LIMIT {sample_size}
            """
            samples = db.execute_query(sample_sql)

            logger.info("\nRandom sample records:")
            for idx, sample in enumerate(samples, 1):
                # Build the log message from the mapped columns and in_out.
                sample_info = []
                for column in select_columns:
                    if column in sample and sample[column] is not None:
                        sample_info.append(f"{column}: {sample[column]}")
                logger.info(f"{idx}. {' | '.join(sample_info) if sample_info else 'No data to display'}")

            return True
        except Exception as e:
            logger.error(f"Data verification failed: {e}")
            return False

    def run_import(self) -> bool:
        """Run the full import workflow."""
        logger.info(f"Starting stock data import, data directory: {self.data_dir}")
        start_time = datetime.now()

        # 1. Find the CSV file.
        csv_file = self.find_csv_file()
        if not csv_file:
            return False

        # 2. Validate the file.
        if not self.validate_file(csv_file):
            return False

        # 3. Read the CSV data.
        if not self.read_csv_data(csv_file):
            return False

        # 4. Clean the data.
        if not self.clean_stock_data():
            return False

        # Show the first five rows (whatever mapped columns are available).
        logger.info("\nFirst 5 stock records:")
        for i, (_, row) in enumerate(self.df.head().iterrows(), 1):
            display_info = []
            if 'stock_code' in row:
                display_info.append(f"code: {row['stock_code']}")
            if 'stock_name_chn' in row:
                display_info.append(f"name: {row['stock_name_chn']}")
            if 'stock_name_en' in row:
                display_info.append(f"en: {row['stock_name_en']}")
            logger.info(f"{i}. {' | '.join(display_info)}")

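        # Design note on the in/out flag: setInOut first resets in_out to 0 on
        # every existing row, then the upsert in insert_data_to_db writes
        # in_out = 1 for every security in the freshly imported list. Rows
        # left at 0 are securities that have dropped out of the list, and
        # their untouched updated_at records when they were last seen.
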
        # 5. Connect to the database and import.
        try:
            with MySQLHelper(**self.db_config) as db:
                # 5.1 Create the table.
                if not self.create_stocks_table(db):
                    return False

                # Reset the in/out flags before the fresh list is inserted.
                self.setInOut(db)

                # 5.2 Insert the data.
                if not self.insert_data_to_db(db):
                    return False

                # 5.3 Verify the data.
                if not self.verify_data_in_db(db):
                    return False
        except Exception as e:
            logger.error(f"Database operation failed: {e}")
            return False

        # Report the elapsed time.
        duration = datetime.now() - start_time
        logger.info(f"Import completed successfully! Total time: {duration.total_seconds():.2f}s")
        return True


if __name__ == "__main__":
    # Database configuration.
    db_config = {
        'host': 'localhost',
        'user': 'root',
        'password': 'bzskmysql',
        'database': 'hk_kline_1d'
    }

    # Column mapping (CSV header -> database column).
    COLUMN_MAPPING = {
        '证券代码': 'stock_code',
        '中文简称': 'stock_name_chn',
        '英文简称': 'stock_name_en',
    }

    # Directory of the current script.
    current_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()

    # Data directory.
    DATA_DIR = current_dir / "data"

    # Make sure the data directory exists.
    DATA_DIR.mkdir(exist_ok=True, parents=True)

    # Create the importer and run the import (note the constructor's parameter order).
    importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR)

    # Set the upload filename (comment out to use the default GPLIST.csv).
    importer.setUploadfile("港股通标的证券名单.csv")

    if importer.run_import():
        logger.info("Stock data import succeeded!")
    else:
        logger.error("Stock data import failed, check the logs for details")
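
# ---------------------------------------------------------------------------
# For reference: a minimal sketch of the interface this script assumes from
# the local .MySQLHelper module -- a context manager exposing execute_query,
# execute_update and execute_many. The hypothetical class below (backed here
# by PyMySQL) only documents the expected contract; the real helper may
# differ in details such as connection pooling or error handling.
# ---------------------------------------------------------------------------
try:
    import pymysql.cursors  # assumption: the helper wraps a DB-API driver

    class MySQLHelperSketch:
        """Hypothetical stand-in showing the contract used by StockDataImporter."""

        def __init__(self, host: str, user: str, password: str, database: str):
            # DictCursor so SELECT results come back as dicts, matching
            # accesses like result[0]['total'] above.
            self.conn = pymysql.connect(
                host=host, user=user, password=password, database=database,
                charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            self.conn.close()

        def execute_query(self, sql: str) -> list:
            """Run a SELECT and return the rows as a list of dicts."""
            with self.conn.cursor() as cur:
                cur.execute(sql)
                return cur.fetchall()

        def execute_update(self, sql: str) -> int:
            """Run a single DML/DDL statement and return the affected row count."""
            with self.conn.cursor() as cur:
                affected = cur.execute(sql)
            self.conn.commit()
            return affected

        def execute_many(self, sql: str, params_list: list) -> int:
            """Run one parameterized statement over a whole batch of tuples."""
            with self.conn.cursor() as cur:
                affected = cur.executemany(sql, params_list)
            self.conn.commit()
            return affected
except ImportError:
    # PyMySQL is only needed for this illustrative sketch, not for the script.
    pass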