""" 存储深圳交易所股票列表数据 不确定其数据爬取规则,防止 IP 被封 暂时使用该方案,获取股票列表数据 —— 下载excel,收到导入到数据库 """ from pathlib import Path from datetime import datetime from MySQLHelper import MySQLHelper from LogHelper import LogHelper import pandas as pd import os import sys import csv import chardet logger = LogHelper(logger_name = 'SZ_Import').setup() class StockDataImporter: """股票数据导入工具(支持新版CSV格式)""" # 新版CSV列名映射 COLUMN_MAPPING = { '板块': 'market_type', '公司全称': 'company_full_name', '英文名称': 'eng_name', '注册地址': 'registered_address', 'A股代码': 'a_stock_code', 'A股简称': 'a_stock_short_name', 'A股上市日期': 'a_listing_date', 'A股总股本': 'a_total_shares', 'A股流通股本': 'a_circulating_shares', 'B股代码': 'b_stock_code', 'B股 简 称': 'b_stock_short_name', 'B股上市日期': 'b_listing_date', 'B股总股本': 'b_total_shares', 'B股流通股本': 'b_circulating_shares', '地 区': 'region', '省 份': 'province', '城 市': 'city', '所属行业': 'industry', '公司网址': 'website', '未盈利': 'unprofitable', '具有表决权差异安排': 'voting_rights_difference', '具有协议控制架构': 'agreement_control_structure' } def __init__(self, data_dir: Path, db_config: dict): self.data_dir = data_dir self.db_config = db_config self.df = None self.csv_file = None self.encoding = 'utf-8' # 默认编码 self.delimiter = ',' # 默认分隔符 def find_csv_file(self) -> Path: """在data文件夹中查找CSV文件""" # 查找所有CSV文件 csv_files = list(self.data_dir.glob("A股列表.csv")) if not csv_files: logger.error(f"在 {self.data_dir} 中没有找到CSV文件") return None # 如果有多个文件,选择最新的 if len(csv_files) > 1: csv_files.sort(key=os.path.getmtime, reverse=True) logger.info(f"找到多个CSV文件,选择最新的: {csv_files[0].name}") return csv_files[0] def validate_file(self, file_path: Path) -> bool: """验证CSV文件是否有效""" try: if not file_path.exists(): logger.error(f"CSV文件不存在: {file_path}") return False file_size = file_path.stat().st_size if file_size == 0: logger.error(f"CSV文件为空: {file_path}") return False return True except Exception as e: logger.error(f"文件验证失败: {e}") return False def detect_file_encoding(self, file_path: Path) -> str: """检测文件编码""" try: # 读取文件开头部分进行编码检测 with open(file_path, 'rb') as f: raw_data = f.read(10000) # 读取前10KB # 使用chardet检测编码 result = chardet.detect(raw_data) encoding = result['encoding'] confidence = result['confidence'] # 常见编码替代 encoding_map = { 'GB2312': 'GBK', 'gb2312': 'GBK', 'ISO-8859-1': 'latin1', 'ascii': 'utf-8' } # 应用映射 encoding = encoding_map.get(encoding, encoding) logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})") return encoding or 'utf-8' except Exception as e: logger.error(f"编码检测失败: {e}, 使用默认UTF-8") return 'utf-8' def detect_csv_delimiter(self, file_path: Path) -> str: """自动检测CSV分隔符""" try: # 使用检测到的编码打开文件 with open(file_path, 'r', encoding=self.encoding) as f: # 读取前5行 lines = [f.readline() for _ in range(5) if f.readline()] # 尝试常见分隔符 delimiters = [',', '\t', ';', '|'] delimiter_counts = {} for delim in delimiters: count = 0 for line in lines: count += line.count(delim) delimiter_counts[delim] = count # 选择出现次数最多的分隔符 best_delim = max(delimiter_counts, key=delimiter_counts.get) # 如果没有任何分隔符,则使用逗号 if delimiter_counts[best_delim] == 0: logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符") return ',' logger.info(f"检测到分隔符: {repr(best_delim)}") return best_delim except Exception as e: logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符") return ',' def read_csv_data(self, file_path: Path) -> bool: """从CSV文件读取数据""" try: # 1. 检测文件编码 self.encoding = self.detect_file_encoding(file_path) # 2. 检测分隔符 self.delimiter = self.detect_csv_delimiter(file_path) # 3. 读取CSV文件 logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件") self.df = pd.read_csv( file_path, delimiter=self.delimiter, dtype=str, encoding=self.encoding, on_bad_lines='warn', quoting=csv.QUOTE_MINIMAL, engine='python' # 更健壮的引擎 ) # 检查是否读取到数据 if self.df.empty: logger.error("CSV文件没有包含有效数据") return False # 重命名列 self.df = self.df.rename(columns=self.COLUMN_MAPPING) # 移除可能存在的空行 self.df = self.df.dropna(how='all') logger.info(f"成功读取CSV数据,共 {len(self.df)} 条记录") return True except UnicodeDecodeError: # 尝试其他编码 encodings_to_try = ['GBK', 'latin1', 'ISO-8859-1', 'utf-16'] for enc in encodings_to_try: try: logger.warning(f"尝试使用 {enc} 编码读取文件") self.df = pd.read_csv( file_path, delimiter=self.delimiter, dtype=str, encoding=enc ) self.encoding = enc logger.info(f"成功使用 {enc} 编码读取文件") return True except: continue logger.error("所有编码尝试均失败") return False except PermissionError: logger.error(f"文件被占用,请关闭后重试: {file_path}") return False except Exception as e: logger.error(f"读取CSV文件失败: {e}") return False def clean_stock_data(self) -> bool: """清洗股票数据(修复了website字段的NaN处理问题)""" try: # 处理数字字段中的逗号 numeric_columns = [ 'a_total_shares', 'a_circulating_shares', 'b_total_shares', 'b_circulating_shares' ] for col in numeric_columns: if col in self.df.columns: # 填充NaN为空字符串 self.df[col] = self.df[col].fillna('') # 转换为字符串 self.df[col] = self.df[col].astype(str) # 移除逗号和空格 self.df[col] = self.df[col].str.replace(',', '').str.replace(' ', '') # 格式化日期字段 date_columns = ['a_listing_date', 'b_listing_date'] for col in date_columns: if col in self.df.columns: # 填充NaN为空字符串 self.df[col] = self.df[col].fillna('') # 转换为datetime,无效日期转为NaT self.df[col] = pd.to_datetime( self.df[col], errors='coerce' ).dt.strftime('%Y-%m-%d') # 将NaT转换为空字符串 self.df[col] = self.df[col].replace('NaT', '') # 处理布尔字段 bool_columns = ['unprofitable', 'voting_rights_difference', 'agreement_control_structure'] for col in bool_columns: if col in self.df.columns: # 填充NaN为0 self.df[col] = self.df[col].fillna('0') # 将"-"转换为0/False self.df[col] = self.df[col].replace('-', '0').replace('', '0') # 转换为整数 self.df[col] = pd.to_numeric(self.df[col], errors='coerce').fillna(0).astype(int) # 转换为布尔值 self.df[col] = self.df[col].astype(bool) # 提取交易所信息 self.df['exchange'] = self.df['a_stock_code'].apply( lambda x: 'SH' if str(x).startswith('60') else 'SZ' if str(x).startswith(('00', '30')) else 'OTHER' ) # 验证A股代码格式 if 'a_stock_code' in self.df.columns: # 填充NaN为空字符串 self.df['a_stock_code'] = self.df['a_stock_code'].fillna('') # 转换为字符串 self.df['a_stock_code'] = self.df['a_stock_code'].astype(str) invalid_codes = self.df[~self.df['a_stock_code'].str.match(r'^\d{6}$')] if not invalid_codes.empty: logger.warning(f"发现 {len(invalid_codes)} 条无效的A股代码") logger.debug(f"无效代码示例: {invalid_codes['a_stock_code'].head().tolist()}") # 清理网址字段 - 修复NaN处理问题 if 'website' in self.df.columns: # 将NaN转换为空字符串 self.df['website'] = self.df['website'].fillna('') # 转换为字符串类型 self.df['website'] = self.df['website'].astype(str) # 执行字符串操作 self.df['website'] = self.df['website'].str.replace(' ', '').str.lower() # 安全地添加http前缀 self.df['website'] = self.df['website'].apply( lambda x: f'http://{x}' if x and not x.startswith('http') else x ) logger.info("数据清洗完成") return True except Exception as e: logger.error(f"数据清洗失败: {e}") return False def create_stocks_table(self, db: MySQLHelper) -> bool: """创建股票信息表(新版)""" create_table_sql = """ CREATE TABLE IF NOT EXISTS stocks_sz ( id INT AUTO_INCREMENT PRIMARY KEY, market_type VARCHAR(10) COMMENT '板块类型', company_full_name VARCHAR(100) NOT NULL COMMENT '公司全称', eng_name VARCHAR(150) COMMENT '英文名称', registered_address VARCHAR(200) COMMENT '注册地址', a_stock_code VARCHAR(6) NOT NULL COMMENT 'A股代码', a_stock_short_name VARCHAR(20) NOT NULL COMMENT 'A股简称', a_listing_date DATE COMMENT 'A股上市日期', a_total_shares BIGINT COMMENT 'A股总股本', a_circulating_shares BIGINT COMMENT 'A股流通股本', b_stock_code VARCHAR(6) COMMENT 'B股代码', b_stock_short_name VARCHAR(20) COMMENT 'B股简称', b_listing_date DATE COMMENT 'B股上市日期', b_total_shares BIGINT COMMENT 'B股总股本', b_circulating_shares BIGINT COMMENT 'B股流通股本', region VARCHAR(20) COMMENT '地区', province VARCHAR(20) COMMENT '省份', city VARCHAR(20) COMMENT '城市', industry VARCHAR(50) COMMENT '所属行业', website VARCHAR(100) COMMENT '公司网址', unprofitable BOOLEAN DEFAULT 0 COMMENT '未盈利', voting_rights_difference BOOLEAN DEFAULT 0 COMMENT '具有表决权差异安排', agreement_control_structure BOOLEAN DEFAULT 0 COMMENT '具有协议控制架构', exchange VARCHAR(2) COMMENT '交易所', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY (a_stock_code) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='沪深股票详细信息表'; """ try: db.execute_update(create_table_sql) logger.info("股票信息表打开成功") return True except Exception as e: logger.error(f"创建表失败: {e}") return False def insert_data_to_db(self, db: MySQLHelper) -> bool: """将数据插入数据库""" if self.df is None or self.df.empty: logger.error("没有有效数据可插入") return False # 准备SQL语句(支持重复记录更新) insert_sql = """ INSERT INTO stocks_sz ( market_type, company_full_name, eng_name, registered_address, a_stock_code, a_stock_short_name, a_listing_date, a_total_shares, a_circulating_shares, b_stock_code, b_stock_short_name, b_listing_date, b_total_shares, b_circulating_shares, region, province, city, industry, website, unprofitable, voting_rights_difference, agreement_control_structure, exchange ) VALUES ( %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s ) ON DUPLICATE KEY UPDATE market_type = VALUES(market_type), company_full_name = VALUES(company_full_name), eng_name = VALUES(eng_name), registered_address = VALUES(registered_address), a_stock_short_name = VALUES(a_stock_short_name), a_listing_date = VALUES(a_listing_date), a_total_shares = VALUES(a_total_shares), a_circulating_shares = VALUES(a_circulating_shares), b_stock_code = VALUES(b_stock_code), b_stock_short_name = VALUES(b_stock_short_name), b_listing_date = VALUES(b_listing_date), b_total_shares = VALUES(b_total_shares), b_circulating_shares = VALUES(b_circulating_shares), region = VALUES(region), province = VALUES(province), city = VALUES(city), industry = VALUES(industry), website = VALUES(website), unprofitable = VALUES(unprofitable), voting_rights_difference = VALUES(voting_rights_difference), agreement_control_structure = VALUES(agreement_control_structure), exchange = VALUES(exchange) """ # 准备参数列表 params_list = [] for _, row in self.df.iterrows(): # 处理空值 def get_value(col, default=None): return row[col] if col in row and pd.notna(row[col]) else default # 处理数字字段 def get_numeric(col, default=0): value = get_value(col, default) try: return int(value) if value != '' and value is not None else default except: return default # 处理日期字段 def get_date(col, default='1970-01-01'): value = get_value(col, default) if value in ['', None, 'NaT']: return default return value # 处理布尔字段 def get_bool(col, default=False): value = get_value(col, default) if value in [True, '1', 1, 'Y', 'y', '是']: return True if value in [False, '0', 0, 'N', 'n', '否', '-', '']: return False return default params = ( get_value('market_type'), # market_type get_value('company_full_name', ''), # company_full_name get_value('eng_name'), # eng_name get_value('registered_address'), # registered_address get_value('a_stock_code', ''), # a_stock_code get_value('a_stock_short_name', ''), # a_stock_short_name get_date('a_listing_date'), # a_listing_date get_numeric('a_total_shares', 0), # a_total_shares get_numeric('a_circulating_shares', 0), # a_circulating_shares get_value('b_stock_code'), # b_stock_code get_value('b_stock_short_name'), # b_stock_short_name get_date('b_listing_date'), # b_listing_date get_numeric('b_total_shares', 0), # b_total_shares get_numeric('b_circulating_shares', 0), # b_circulating_shares get_value('region'), # region get_value('province'), # province get_value('city'), # city get_value('industry'), # industry get_value('website'), # website get_bool('unprofitable'), # unprofitable get_bool('voting_rights_difference'), # voting_rights_difference get_bool('agreement_control_structure'), # agreement_control_structure get_value('exchange', '') # exchange ) params_list.append(params) # 批量执行插入 try: total_rows = len(params_list) if total_rows == 0: logger.error("没有有效数据可插入") return False batch_size = 500 # 每批插入500条记录(因为字段较多) logger.info(f"开始插入数据,共 {total_rows} 条记录") # 分批插入,避免大事务问题 for i in range(0, total_rows, batch_size): batch_params = params_list[i:i+batch_size] affected_rows = db.execute_many(insert_sql, batch_params) logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录") logger.info(f"成功插入/更新 {total_rows} 条记录") return True except Exception as e: logger.error(f"插入数据失败: {e}") # 记录前5个参数以帮助调试 if params_list: logger.debug(f"前5个参数示例: {params_list[:5]}") return False def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool: """验证数据库中的数据""" try: # 检查记录总数 count_sql = "SELECT COUNT(*) AS total FROM stocks_sz" result = db.execute_query(count_sql) db_count = result[0]['total'] if result else 0 logger.info(f"数据库中共有 {db_count} 条记录") # 随机抽样检查 sample_sql = f""" SELECT a_stock_code, a_stock_short_name, a_listing_date, province, city FROM stocks_sz ORDER BY RAND() LIMIT {sample_size} """ samples = db.execute_query(sample_sql) logger.info("\n随机抽样记录:") for idx, sample in enumerate(samples, 1): logger.info(f"{idx}. {sample['a_stock_code']}: {sample['a_stock_short_name']} ({sample['a_listing_date']}) - {sample['province']}{sample['city']}") return True except Exception as e: logger.error(f"数据验证失败: {e}") return False def run_import(self) -> bool: """执行完整的导入流程""" logger.info(f"开始导入股票数据,数据目录: {self.data_dir}") start_time = datetime.now() # 1. 查找CSV文件 csv_file = self.find_csv_file() if not csv_file: return False # 2. 验证文件 if not self.validate_file(csv_file): return False # 3. 读取CSV数据 if not self.read_csv_data(csv_file): return False # 4. 清洗数据 if not self.clean_stock_data(): return False # 显示前5条数据 logger.info("\n前5条股票数据:") for i, row in self.df.head().iterrows(): logger.info(f"{row['a_stock_code']}: {row['a_stock_short_name']} ({row['a_listing_date']}) - {row['province']}{row['city']}") # 5. 连接数据库并导入 try: with MySQLHelper(**self.db_config) as db: # 5.1 创建表 if not self.create_stocks_table(db): return False # 5.2 插入数据 if not self.insert_data_to_db(db): return False # 5.3 验证数据 if not self.verify_data_in_db(db): return False except Exception as e: logger.error(f"数据库操作异常: {e}") return False # 计算执行时间 duration = datetime.now() - start_time logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}秒") return True if __name__ == "__main__": # 数据库配置 DB_CONFIG = { 'host': 'localhost', 'user': 'root', 'password': 'bzskmysql', 'database': 'fullmarketdata_a' } # 获取当前脚本所在目录 current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() # 设置数据目录 DATA_DIR = current_dir / "data" # 确保data目录存在 DATA_DIR.mkdir(exist_ok=True, parents=True) # 安装依赖 (如果chardet未安装) try: import chardet except ImportError: logger.info("安装chardet库以支持编码检测...") import subprocess subprocess.check_call([sys.executable, "-m", "pip", "install", "chardet"]) import chardet # 创建导入器并执行导入 importer = StockDataImporter(DATA_DIR, DB_CONFIG) if importer.run_import(): logger.info("股票数据导入成功!") else: logger.error("股票数据导入失败,请检查日志了解详情")