diff --git a/.gitignore b/.gitignore index ce9237b..e695a2e 100644 --- a/.gitignore +++ b/.gitignore @@ -264,3 +264,5 @@ compile_commands.json *_qmlcache.qrc + +data/ diff --git a/MySQLHelper.py b/MySQLHelper.py new file mode 100644 index 0000000..8055872 --- /dev/null +++ b/MySQLHelper.py @@ -0,0 +1,135 @@ +import pymysql +from pymysql import Error +from typing import List, Dict, Union, Optional, Tuple + +class MySQLHelper: + def __init__(self, host: str, user: str, password: str, database: str, + port: int = 3306, charset: str = 'utf8mb4'): + """ + 初始化MySQL连接参数 + :param host: 数据库地址 + :param user: 用户名 + :param password: 密码 + :param database: 数据库名 + :param port: 端口,默认3306 + :param charset: 字符集,默认utf8mb4 + """ + self.host = host + self.user = user + self.password = password + self.database = database + self.port = port + self.charset = charset + self.connection = None + self.cursor = None + + def connect(self) -> bool: + """ + 连接到MySQL数据库 + :return: 连接成功返回True,失败返回False + """ + try: + self.connection = pymysql.connect( + host=self.host, + user=self.user, + password=self.password, + database=self.database, + port=self.port, + charset=self.charset, + cursorclass=pymysql.cursors.DictCursor # 返回字典形式的结果 + ) + self.cursor = self.connection.cursor() + print("MySQL数据库连接成功") + return True + except Error as e: + print(f"连接MySQL数据库失败: {e}") + return False + + def close(self) -> None: + """ + 关闭数据库连接 + """ + if self.cursor: + self.cursor.close() + if self.connection: + self.connection.close() + print("MySQL数据库连接已关闭") + + def execute_query(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> List[Dict]: + """ + 执行查询操作 + :param sql: SQL语句 + :param params: 参数,可以是元组、列表或字典 + :return: 查询结果列表 + """ + try: + self.cursor.execute(sql, params) + return self.cursor.fetchall() + except Error as e: + print(f"查询执行失败: {e}") + return [] + + def execute_update(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> int: + """ + 执行更新操作(INSERT/UPDATE/DELETE) + :param sql: SQL语句 + :param params: 参数,可以是元组、列表或字典 + :return: 影响的行数 + """ + try: + affected_rows = self.cursor.execute(sql, params) + self.connection.commit() + return affected_rows + except Error as e: + self.connection.rollback() + print(f"更新执行失败: {e}") + return 0 + + def execute_many(self, sql: str, params_list: List[Union[Tuple, List, Dict]]) -> int: + """ + 批量执行操作 + :param sql: SQL语句 + :param params_list: 参数列表 + :return: 影响的行数 + """ + try: + affected_rows = self.cursor.executemany(sql, params_list) + self.connection.commit() + return affected_rows + except Error as e: + self.connection.rollback() + print(f"批量执行失败: {e}") + return 0 + + def get_one(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> Optional[Dict]: + """ + 获取单条记录 + :param sql: SQL语句 + :param params: 参数,可以是元组、列表或字典 + :return: 单条记录或None + """ + try: + self.cursor.execute(sql, params) + return self.cursor.fetchone() + except Error as e: + print(f"获取单条记录失败: {e}") + return None + + def table_exists(self, table_name: str) -> bool: + """ + 检查表是否存在 + :param table_name: 表名 + :return: 存在返回True,否则返回False + """ + sql = "SHOW TABLES LIKE %s" + result = self.execute_query(sql, (table_name,)) + return len(result) > 0 + + def __enter__(self): + """支持with上下文管理""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """支持with上下文管理""" + self.close() \ No newline at end of file diff --git a/exportExcelToDB_SH.py b/exportExcelToDB_SH.py new file mode 100644 index 0000000..abd8703 --- /dev/null +++ b/exportExcelToDB_SH.py @@ -0,0 +1,528 @@ +""" + 读取上海证券交易所官网股票列表数据写入数据库 + + 上海和深圳拿到的数据表头不一样,所以分开解析和存储 +""" + +import pandas as pd +import pymysql +from pymysql import Error +from pathlib import Path +import os +import logging +from datetime import datetime +import sys +import csv +import chardet # 用于检测文件编码 +from typing import List, Dict, Union, Tuple, Optional + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('Debug.log', encoding='utf-8'), # 关键在这里 + logging.StreamHandler() + ] +) + +logger = logging.getLogger('StockDataImporter') + +class MySQLHelper: + """MySQL 数据库操作助手类""" + def __init__(self, host: str, user: str, password: str, database: str, + port: int = 3306, charset: str = 'utf8mb4'): + self.host = host + self.user = user + self.password = password + self.database = database + self.port = port + self.charset = charset + self.connection = None + self.cursor = None + + def connect(self) -> bool: + try: + self.connection = pymysql.connect( + host=self.host, + user=self.user, + password=self.password, + database=self.database, + port=self.port, + charset=self.charset, + cursorclass=pymysql.cursors.DictCursor + ) + self.cursor = self.connection.cursor() + logger.info("MySQL数据库连接成功") + return True + except Error as e: + logger.error(f"连接MySQL数据库失败: {e}") + return False + + def close(self) -> None: + if self.cursor: + self.cursor.close() + if self.connection: + self.connection.close() + logger.info("MySQL数据库连接已关闭") + + def execute_query(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> List[Dict]: + try: + self.cursor.execute(sql, params) + return self.cursor.fetchall() + except Error as e: + logger.error(f"查询执行失败: {e}") + return [] + + def execute_update(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> int: + try: + affected_rows = self.cursor.execute(sql, params) + self.connection.commit() + return affected_rows + except Error as e: + self.connection.rollback() + logger.error(f"更新执行失败: {e}") + return 0 + + def execute_many(self, sql: str, params_list: List[Union[Tuple, List, Dict]]) -> int: + try: + affected_rows = self.cursor.executemany(sql, params_list) + self.connection.commit() + return affected_rows + except Error as e: + self.connection.rollback() + logger.error(f"批量执行失败: {e}") + return 0 + + def get_one(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> Optional[Dict]: + try: + self.cursor.execute(sql, params) + return self.cursor.fetchone() + except Error as e: + logger.error(f"获取单条记录失败: {e}") + return None + + def table_exists(self, table_name: str) -> bool: + sql = "SHOW TABLES LIKE %s" + result = self.execute_query(sql, (table_name,)) + return len(result) > 0 + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + +class StockDataImporter: + """股票数据导入工具(支持CSV)""" + + COLUMN_MAPPING = { + 'A股代码': 'a_stock_code', + 'B股代码': 'b_stock_code', + '证券简称': 'short_name', + '扩位证券简称': 'extended_name', + '公司英文全称': 'eng_name', + '上市日期': 'listing_date' + } + + def __init__(self, data_dir: Path, db_config: dict): + self.data_dir = data_dir + self.db_config = db_config + self.df = None + self.csv_file = None + self.encoding = 'utf-8' # 默认编码 + self.delimiter = ',' # 默认分隔符 + + def find_csv_file(self) -> Path: + """在data文件夹中查找CSV文件""" + # 查找所有CSV文件 + csv_files = list(self.data_dir.glob("GPLIST.csv")) + + if not csv_files: + logger.error(f"在 {self.data_dir} 中没有找到CSV文件") + return None + + # 如果有多个文件,选择最新的 + if len(csv_files) > 1: + csv_files.sort(key=os.path.getmtime, reverse=True) + logger.info(f"找到多个CSV文件,选择最新的: {csv_files[0].name}") + + return csv_files[0] + + def validate_file(self, file_path: Path) -> bool: + """验证CSV文件是否有效""" + try: + if not file_path.exists(): + logger.error(f"CSV文件不存在: {file_path}") + return False + + file_size = file_path.stat().st_size + if file_size == 0: + logger.error(f"CSV文件为空: {file_path}") + return False + + return True + except Exception as e: + logger.error(f"文件验证失败: {e}") + return False + + def detect_file_encoding(self, file_path: Path) -> str: + """检测文件编码""" + try: + # 读取文件开头部分进行编码检测 + with open(file_path, 'rb') as f: + raw_data = f.read(10000) # 读取前10KB + + # 使用chardet检测编码 + result = chardet.detect(raw_data) + encoding = result['encoding'] + confidence = result['confidence'] + + # 常见编码替代 + encoding_map = { + 'GB2312': 'GBK', + 'gb2312': 'GBK', + 'ISO-8859-1': 'latin1', + 'ascii': 'utf-8' + } + + # 应用映射 + encoding = encoding_map.get(encoding, encoding) + + logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})") + return encoding or 'utf-8' + except Exception as e: + logger.error(f"编码检测失败: {e}, 使用默认UTF-8") + return 'utf-8' + + def detect_csv_delimiter(self, file_path: Path) -> str: + """自动检测CSV分隔符""" + try: + # 使用检测到的编码打开文件 + with open(file_path, 'r', encoding=self.encoding) as f: + # 读取前5行 + lines = [f.readline() for _ in range(5) if f.readline()] + + # 尝试常见分隔符 + delimiters = [',', '\t', ';', '|'] + delimiter_counts = {} + + for delim in delimiters: + count = 0 + for line in lines: + count += line.count(delim) + delimiter_counts[delim] = count + + # 选择出现次数最多的分隔符 + best_delim = max(delimiter_counts, key=delimiter_counts.get) + + # 如果没有任何分隔符,则使用逗号 + if delimiter_counts[best_delim] == 0: + logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符") + return ',' + + logger.info(f"检测到分隔符: {repr(best_delim)}") + return best_delim + except Exception as e: + logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符") + return ',' + + def read_csv_data(self, file_path: Path) -> bool: + """从CSV文件读取数据""" + try: + # 1. 检测文件编码 + self.encoding = self.detect_file_encoding(file_path) + + # 2. 检测分隔符 + self.delimiter = self.detect_csv_delimiter(file_path) + + # 3. 读取CSV文件 + logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件") + + self.df = pd.read_csv( + file_path, + delimiter=self.delimiter, + dtype=str, + encoding=self.encoding, + on_bad_lines='warn', + quoting=csv.QUOTE_MINIMAL, + engine='python' # 更健壮的引擎 + ) + + # 检查是否读取到数据 + if self.df.empty: + logger.error("CSV文件没有包含有效数据") + return False + + # 重命名列 + self.df = self.df.rename(columns=self.COLUMN_MAPPING) + + # 移除可能存在的空行 + self.df = self.df.dropna(how='all') + + logger.info(f"成功读取CSV数据,共 {len(self.df)} 条记录") + return True + except UnicodeDecodeError: + # 尝试其他编码 + encodings_to_try = ['GBK', 'latin1', 'ISO-8859-1', 'utf-16'] + for enc in encodings_to_try: + try: + logger.warning(f"尝试使用 {enc} 编码读取文件") + self.df = pd.read_csv( + file_path, + delimiter=self.delimiter, + dtype=str, + encoding=enc + ) + self.encoding = enc + logger.info(f"成功使用 {enc} 编码读取文件") + return True + except: + continue + + logger.error("所有编码尝试均失败") + return False + except PermissionError: + logger.error(f"文件被占用,请关闭后重试: {file_path}") + return False + except Exception as e: + logger.error(f"读取CSV文件失败: {e}") + return False + + def clean_stock_data(self) -> bool: + """清洗股票数据""" + try: + # 处理B股代码:将'-'转换为None + self.df['b_stock_code'] = self.df['b_stock_code'].replace('-', None) + + # 格式化上市日期 + self.df['listing_date'] = pd.to_datetime( + self.df['listing_date'], + format='%Y%m%d', + errors='coerce' + ).dt.strftime('%Y-%m-%d') + + # 检查日期转换是否成功 + date_na_count = self.df['listing_date'].isna().sum() + if date_na_count > 0: + logger.warning(f"发现 {date_na_count} 条记录的上市日期格式不正确") + + # 提取交易所信息 + self.df['exchange'] = self.df['a_stock_code'].apply( + lambda x: 'SH' if str(x).startswith('60') else 'SZ' if str(x).startswith(('00', '30')) else 'OTHER' + ) + + # 验证A股代码格式 + invalid_codes = self.df[~self.df['a_stock_code'].astype(str).str.match(r'^\d{6}$')] + if not invalid_codes.empty: + logger.warning(f"发现 {len(invalid_codes)} 条无效的A股代码") + logger.debug(f"无效代码示例: {invalid_codes['a_stock_code'].head().tolist()}") + + logger.info("数据清洗完成") + return True + except Exception as e: + logger.error(f"数据清洗失败: {e}") + return False + + def create_stocks_table(self, db: MySQLHelper) -> bool: + """创建股票信息表""" + create_table_sql = """ + CREATE TABLE IF NOT EXISTS stocks_sh ( + a_stock_code VARCHAR(6) PRIMARY KEY COMMENT 'A股代码', + b_stock_code VARCHAR(6) COMMENT 'B股代码', + short_name VARCHAR(50) NOT NULL COMMENT '证券简称', + extended_name VARCHAR(100) COMMENT '扩位证券简称', + eng_name VARCHAR(150) COMMENT '公司英文全称', + listing_date DATE NOT NULL COMMENT '上市日期', + exchange VARCHAR(2) NOT NULL COMMENT '交易所', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间' + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='沪深股票信息表'; + """ + + try: + db.execute_update(create_table_sql) + logger.info("股票信息表创建成功") + return True + except Exception as e: + logger.error(f"创建表失败: {e}") + return False + + def insert_data_to_db(self, db: MySQLHelper) -> bool: + """将数据插入数据库""" + if self.df is None or self.df.empty: + logger.error("没有有效数据可插入") + return False + + # 准备SQL语句(支持重复记录更新) + insert_sql = """ + INSERT INTO stocks_sh ( + a_stock_code, b_stock_code, short_name, + extended_name, eng_name, listing_date, exchange + ) VALUES ( + %s, %s, %s, %s, %s, %s, %s + ) + ON DUPLICATE KEY UPDATE + b_stock_code = VALUES(b_stock_code), + short_name = VALUES(short_name), + extended_name = VALUES(extended_name), + eng_name = VALUES(eng_name), + listing_date = VALUES(listing_date), + exchange = VALUES(exchange) + """ + + # 准备参数列表 + params_list = [] + for _, row in self.df.iterrows(): + # 处理可能的NaN值 + listing_date = row['listing_date'] if pd.notna(row['listing_date']) else '1970-01-01' + + params_list.append(( + row['a_stock_code'], + row['b_stock_code'] if pd.notna(row['b_stock_code']) else None, + row['short_name'], + row['extended_name'] if pd.notna(row['extended_name']) else None, + row['eng_name'] if pd.notna(row['eng_name']) else None, + listing_date, + row['exchange'] + )) + + # 批量执行插入 + try: + total_rows = len(params_list) + if total_rows == 0: + logger.error("没有有效数据可插入") + return False + + batch_size = 1000 # 每批插入1000条记录 + + logger.info(f"开始插入数据,共 {total_rows} 条记录") + + # 分批插入,避免大事务问题 + for i in range(0, total_rows, batch_size): + batch_params = params_list[i:i+batch_size] + affected_rows = db.execute_many(insert_sql, batch_params) + logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录") + + logger.info(f"成功插入/更新 {total_rows} 条记录") + return True + except Exception as e: + logger.error(f"插入数据失败: {e}") + # 记录前5个参数以帮助调试 + if params_list: + logger.debug(f"前5个参数示例: {params_list[:5]}") + return False + + def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool: + """验证数据库中的数据""" + try: + # 检查记录总数 + count_sql = "SELECT COUNT(*) AS total FROM stocks_sh" + result = db.execute_query(count_sql) + db_count = result[0]['total'] if result else 0 + logger.info(f"数据库中共有 {db_count} 条记录") + + # 随机抽样检查 + sample_sql = f""" + SELECT a_stock_code, short_name, listing_date + FROM stocks_sh + ORDER BY RAND() + LIMIT {sample_size} + """ + samples = db.execute_query(sample_sql) + + logger.info("\n随机抽样记录:") + for idx, sample in enumerate(samples, 1): + logger.info(f"{idx}. {sample['a_stock_code']}: {sample['short_name']} ({sample['listing_date']})") + + return True + except Exception as e: + logger.error(f"数据验证失败: {e}") + return False + + def run_import(self) -> bool: + """执行完整的导入流程""" + logger.info(f"开始导入股票数据,数据目录: {self.data_dir}") + start_time = datetime.now() + + # 1. 查找CSV文件 + csv_file = self.find_csv_file() + if not csv_file: + return False + + # 2. 验证文件 + if not self.validate_file(csv_file): + return False + + # 3. 读取CSV数据 + if not self.read_csv_data(csv_file): + return False + + # 4. 清洗数据 + if not self.clean_stock_data(): + return False + + # 显示前5条数据 + logger.info("\n前5条股票数据:") + for i, row in self.df.head().iterrows(): + logger.info(f"{row['a_stock_code']}: {row['short_name']} ({row['listing_date']})") + + # 5. 连接数据库并导入 + try: + with MySQLHelper(**self.db_config) as db: + # 5.1 创建表 + if not self.create_stocks_table(db): + return False + + # 5.2 插入数据 + if not self.insert_data_to_db(db): + return False + + # 5.3 验证数据 + if not self.verify_data_in_db(db): + return False + except Exception as e: + logger.error(f"数据库操作异常: {e}") + return False + + # 计算执行时间 + duration = datetime.now() - start_time + logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}秒") + return True + +if __name__ == "__main__": + + # 数据库配置 + db_config = { + 'host': 'localhost', + 'user': 'root', + 'password': 'bzskmysql', + 'database': 'fullmarketdata_a' + } + + # 获取当前脚本所在目录 + current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() + + # 设置数据目录 + DATA_DIR = current_dir / "data" + + # 确保data目录存在 + DATA_DIR.mkdir(exist_ok=True, parents=True) + + # 安装依赖 (如果chardet未安装) + try: + import chardet + except ImportError: + logger.info("安装chardet库以支持编码检测...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "chardet"]) + import chardet + + # 创建导入器并执行导入 + importer = StockDataImporter(DATA_DIR, db_config) + + if importer.run_import(): + logger.info("股票数据导入成功!") + else: + logger.error("股票数据导入失败,请检查日志了解详情") \ No newline at end of file diff --git a/exportExcelToDB_SZ.py b/exportExcelToDB_SZ.py new file mode 100644 index 0000000..138cb88 --- /dev/null +++ b/exportExcelToDB_SZ.py @@ -0,0 +1,673 @@ +""" + 下载的深圳交易所的数据 + + 表头和上海交易所的数据格式不一致,所以分开存储 +""" +import pandas as pd +import pymysql +from pymysql import Error +from pathlib import Path +import os +import logging +from datetime import datetime +import sys +import csv +import chardet +import re +from typing import List, Dict, Union, Tuple, Optional + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler('stock_data_import.log') + ] +) +logger = logging.getLogger('StockDataImporter') + +class MySQLHelper: + """MySQL 数据库操作助手类""" + def __init__(self, host: str, user: str, password: str, database: str, + port: int = 3306, charset: str = 'utf8mb4'): + self.host = host + self.user = user + self.password = password + self.database = database + self.port = port + self.charset = charset + self.connection = None + self.cursor = None + + def connect(self) -> bool: + try: + self.connection = pymysql.connect( + host=self.host, + user=self.user, + password=self.password, + database=self.database, + port=self.port, + charset=self.charset, + cursorclass=pymysql.cursors.DictCursor + ) + self.cursor = self.connection.cursor() + logger.info("MySQL数据库连接成功") + return True + except Error as e: + logger.error(f"连接MySQL数据库失败: {e}") + return False + + def close(self) -> None: + if self.cursor: + self.cursor.close() + if self.connection: + self.connection.close() + logger.info("MySQL数据库连接已关闭") + + def execute_query(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> List[Dict]: + try: + self.cursor.execute(sql, params) + return self.cursor.fetchall() + except Error as e: + logger.error(f"查询执行失败: {e}") + return [] + + def execute_update(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> int: + try: + affected_rows = self.cursor.execute(sql, params) + self.connection.commit() + return affected_rows + except Error as e: + self.connection.rollback() + logger.error(f"更新执行失败: {e}") + return 0 + + def execute_many(self, sql: str, params_list: List[Union[Tuple, List, Dict]]) -> int: + try: + affected_rows = self.cursor.executemany(sql, params_list) + self.connection.commit() + return affected_rows + except Error as e: + self.connection.rollback() + logger.error(f"批量执行失败: {e}") + return 0 + + def get_one(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> Optional[Dict]: + try: + self.cursor.execute(sql, params) + return self.cursor.fetchone() + except Error as e: + logger.error(f"获取单条记录失败: {e}") + return None + + def table_exists(self, table_name: str) -> bool: + sql = "SHOW TABLES LIKE %s" + result = self.execute_query(sql, (table_name,)) + return len(result) > 0 + + def __enter__(self): + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + +class StockDataImporter: + """股票数据导入工具(支持新版CSV格式)""" + + # 新版CSV列名映射 + COLUMN_MAPPING = { + '板块': 'market_type', + '公司全称': 'company_full_name', + '英文名称': 'eng_name', + '注册地址': 'registered_address', + 'A股代码': 'a_stock_code', + 'A股简称': 'a_stock_short_name', + 'A股上市日期': 'a_listing_date', + 'A股总股本': 'a_total_shares', + 'A股流通股本': 'a_circulating_shares', + 'B股代码': 'b_stock_code', + 'B股简称': 'b_stock_short_name', + 'B股上市日期': 'b_listing_date', + 'B股总股本': 'b_total_shares', + 'B股流通股本': 'b_circulating_shares', + '地区': 'region', + '省份': 'province', + '城市': 'city', + '所属行业': 'industry', + '公司网址': 'website', + '未盈利': 'unprofitable', + '具有表决权差异安排': 'voting_rights_difference', + '具有协议控制架构': 'agreement_control_structure' + } + + def __init__(self, data_dir: Path, db_config: dict): + self.data_dir = data_dir + self.db_config = db_config + self.df = None + self.csv_file = None + self.encoding = 'utf-8' # 默认编码 + self.delimiter = ',' # 默认分隔符 + + def find_csv_file(self) -> Path: + """在data文件夹中查找CSV文件""" + # 查找所有CSV文件 + csv_files = list(self.data_dir.glob("A股列表.csv")) + + if not csv_files: + logger.error(f"在 {self.data_dir} 中没有找到CSV文件") + return None + + # 如果有多个文件,选择最新的 + if len(csv_files) > 1: + csv_files.sort(key=os.path.getmtime, reverse=True) + logger.info(f"找到多个CSV文件,选择最新的: {csv_files[0].name}") + + return csv_files[0] + + def validate_file(self, file_path: Path) -> bool: + """验证CSV文件是否有效""" + try: + if not file_path.exists(): + logger.error(f"CSV文件不存在: {file_path}") + return False + + file_size = file_path.stat().st_size + if file_size == 0: + logger.error(f"CSV文件为空: {file_path}") + return False + + return True + except Exception as e: + logger.error(f"文件验证失败: {e}") + return False + + def detect_file_encoding(self, file_path: Path) -> str: + """检测文件编码""" + try: + # 读取文件开头部分进行编码检测 + with open(file_path, 'rb') as f: + raw_data = f.read(10000) # 读取前10KB + + # 使用chardet检测编码 + result = chardet.detect(raw_data) + encoding = result['encoding'] + confidence = result['confidence'] + + # 常见编码替代 + encoding_map = { + 'GB2312': 'GBK', + 'gb2312': 'GBK', + 'ISO-8859-1': 'latin1', + 'ascii': 'utf-8' + } + + # 应用映射 + encoding = encoding_map.get(encoding, encoding) + + logger.info(f"检测到编码: {encoding} (置信度: {confidence:.2f})") + return encoding or 'utf-8' + except Exception as e: + logger.error(f"编码检测失败: {e}, 使用默认UTF-8") + return 'utf-8' + + def detect_csv_delimiter(self, file_path: Path) -> str: + """自动检测CSV分隔符""" + try: + # 使用检测到的编码打开文件 + with open(file_path, 'r', encoding=self.encoding) as f: + # 读取前5行 + lines = [f.readline() for _ in range(5) if f.readline()] + + # 尝试常见分隔符 + delimiters = [',', '\t', ';', '|'] + delimiter_counts = {} + + for delim in delimiters: + count = 0 + for line in lines: + count += line.count(delim) + delimiter_counts[delim] = count + + # 选择出现次数最多的分隔符 + best_delim = max(delimiter_counts, key=delimiter_counts.get) + + # 如果没有任何分隔符,则使用逗号 + if delimiter_counts[best_delim] == 0: + logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符") + return ',' + + logger.info(f"检测到分隔符: {repr(best_delim)}") + return best_delim + except Exception as e: + logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符") + return ',' + + def read_csv_data(self, file_path: Path) -> bool: + """从CSV文件读取数据""" + try: + # 1. 检测文件编码 + self.encoding = self.detect_file_encoding(file_path) + + # 2. 检测分隔符 + self.delimiter = self.detect_csv_delimiter(file_path) + + # 3. 读取CSV文件 + logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件") + + self.df = pd.read_csv( + file_path, + delimiter=self.delimiter, + dtype=str, + encoding=self.encoding, + on_bad_lines='warn', + quoting=csv.QUOTE_MINIMAL, + engine='python' # 更健壮的引擎 + ) + + # 检查是否读取到数据 + if self.df.empty: + logger.error("CSV文件没有包含有效数据") + return False + + # 重命名列 + self.df = self.df.rename(columns=self.COLUMN_MAPPING) + + # 移除可能存在的空行 + self.df = self.df.dropna(how='all') + + logger.info(f"成功读取CSV数据,共 {len(self.df)} 条记录") + return True + except UnicodeDecodeError: + # 尝试其他编码 + encodings_to_try = ['GBK', 'latin1', 'ISO-8859-1', 'utf-16'] + for enc in encodings_to_try: + try: + logger.warning(f"尝试使用 {enc} 编码读取文件") + self.df = pd.read_csv( + file_path, + delimiter=self.delimiter, + dtype=str, + encoding=enc + ) + self.encoding = enc + logger.info(f"成功使用 {enc} 编码读取文件") + return True + except: + continue + + logger.error("所有编码尝试均失败") + return False + except PermissionError: + logger.error(f"文件被占用,请关闭后重试: {file_path}") + return False + except Exception as e: + logger.error(f"读取CSV文件失败: {e}") + return False + + def clean_stock_data(self) -> bool: + """清洗股票数据(修复了website字段的NaN处理问题)""" + try: + # 处理数字字段中的逗号 + numeric_columns = [ + 'a_total_shares', 'a_circulating_shares', + 'b_total_shares', 'b_circulating_shares' + ] + + for col in numeric_columns: + if col in self.df.columns: + # 填充NaN为空字符串 + self.df[col] = self.df[col].fillna('') + # 转换为字符串 + self.df[col] = self.df[col].astype(str) + # 移除逗号和空格 + self.df[col] = self.df[col].str.replace(',', '').str.replace(' ', '') + + # 格式化日期字段 + date_columns = ['a_listing_date', 'b_listing_date'] + for col in date_columns: + if col in self.df.columns: + # 填充NaN为空字符串 + self.df[col] = self.df[col].fillna('') + # 转换为datetime,无效日期转为NaT + self.df[col] = pd.to_datetime( + self.df[col], + errors='coerce' + ).dt.strftime('%Y-%m-%d') + # 将NaT转换为空字符串 + self.df[col] = self.df[col].replace('NaT', '') + + # 处理布尔字段 + bool_columns = ['unprofitable', 'voting_rights_difference', 'agreement_control_structure'] + for col in bool_columns: + if col in self.df.columns: + # 填充NaN为0 + self.df[col] = self.df[col].fillna('0') + # 将"-"转换为0/False + self.df[col] = self.df[col].replace('-', '0').replace('', '0') + # 转换为整数 + self.df[col] = pd.to_numeric(self.df[col], errors='coerce').fillna(0).astype(int) + # 转换为布尔值 + self.df[col] = self.df[col].astype(bool) + + # 提取交易所信息 + self.df['exchange'] = self.df['a_stock_code'].apply( + lambda x: 'SH' if str(x).startswith('60') else 'SZ' if str(x).startswith(('00', '30')) else 'OTHER' + ) + + # 验证A股代码格式 + if 'a_stock_code' in self.df.columns: + # 填充NaN为空字符串 + self.df['a_stock_code'] = self.df['a_stock_code'].fillna('') + # 转换为字符串 + self.df['a_stock_code'] = self.df['a_stock_code'].astype(str) + + invalid_codes = self.df[~self.df['a_stock_code'].str.match(r'^\d{6}$')] + if not invalid_codes.empty: + logger.warning(f"发现 {len(invalid_codes)} 条无效的A股代码") + logger.debug(f"无效代码示例: {invalid_codes['a_stock_code'].head().tolist()}") + + # 清理网址字段 - 修复NaN处理问题 + if 'website' in self.df.columns: + # 将NaN转换为空字符串 + self.df['website'] = self.df['website'].fillna('') + # 转换为字符串类型 + self.df['website'] = self.df['website'].astype(str) + + # 执行字符串操作 + self.df['website'] = self.df['website'].str.replace(' ', '').str.lower() + + # 安全地添加http前缀 + self.df['website'] = self.df['website'].apply( + lambda x: f'http://{x}' if x and not x.startswith('http') else x + ) + + logger.info("数据清洗完成") + return True + except Exception as e: + logger.error(f"数据清洗失败: {e}") + return False + + def create_stocks_table(self, db: MySQLHelper) -> bool: + """创建股票信息表(新版)""" + create_table_sql = """ + CREATE TABLE IF NOT EXISTS stocks_sz ( + id INT AUTO_INCREMENT PRIMARY KEY, + market_type VARCHAR(10) COMMENT '板块类型', + company_full_name VARCHAR(100) NOT NULL COMMENT '公司全称', + eng_name VARCHAR(150) COMMENT '英文名称', + registered_address VARCHAR(200) COMMENT '注册地址', + a_stock_code VARCHAR(6) NOT NULL COMMENT 'A股代码', + a_stock_short_name VARCHAR(20) NOT NULL COMMENT 'A股简称', + a_listing_date DATE COMMENT 'A股上市日期', + a_total_shares BIGINT COMMENT 'A股总股本', + a_circulating_shares BIGINT COMMENT 'A股流通股本', + b_stock_code VARCHAR(6) COMMENT 'B股代码', + b_stock_short_name VARCHAR(20) COMMENT 'B股简称', + b_listing_date DATE COMMENT 'B股上市日期', + b_total_shares BIGINT COMMENT 'B股总股本', + b_circulating_shares BIGINT COMMENT 'B股流通股本', + region VARCHAR(20) COMMENT '地区', + province VARCHAR(20) COMMENT '省份', + city VARCHAR(20) COMMENT '城市', + industry VARCHAR(50) COMMENT '所属行业', + website VARCHAR(100) COMMENT '公司网址', + unprofitable BOOLEAN DEFAULT 0 COMMENT '未盈利', + voting_rights_difference BOOLEAN DEFAULT 0 COMMENT '具有表决权差异安排', + agreement_control_structure BOOLEAN DEFAULT 0 COMMENT '具有协议控制架构', + exchange VARCHAR(2) COMMENT '交易所', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY (a_stock_code) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='沪深股票详细信息表'; + """ + + try: + db.execute_update(create_table_sql) + logger.info("股票信息表创建成功") + return True + except Exception as e: + logger.error(f"创建表失败: {e}") + return False + + def insert_data_to_db(self, db: MySQLHelper) -> bool: + """将数据插入数据库""" + if self.df is None or self.df.empty: + logger.error("没有有效数据可插入") + return False + + # 准备SQL语句(支持重复记录更新) + insert_sql = """ + INSERT INTO stocks_sz ( + market_type, company_full_name, eng_name, registered_address, + a_stock_code, a_stock_short_name, a_listing_date, a_total_shares, a_circulating_shares, + b_stock_code, b_stock_short_name, b_listing_date, b_total_shares, b_circulating_shares, + region, province, city, industry, website, + unprofitable, voting_rights_difference, agreement_control_structure, + exchange + ) VALUES ( + %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s + ) + ON DUPLICATE KEY UPDATE + market_type = VALUES(market_type), + company_full_name = VALUES(company_full_name), + eng_name = VALUES(eng_name), + registered_address = VALUES(registered_address), + a_stock_short_name = VALUES(a_stock_short_name), + a_listing_date = VALUES(a_listing_date), + a_total_shares = VALUES(a_total_shares), + a_circulating_shares = VALUES(a_circulating_shares), + b_stock_code = VALUES(b_stock_code), + b_stock_short_name = VALUES(b_stock_short_name), + b_listing_date = VALUES(b_listing_date), + b_total_shares = VALUES(b_total_shares), + b_circulating_shares = VALUES(b_circulating_shares), + region = VALUES(region), + province = VALUES(province), + city = VALUES(city), + industry = VALUES(industry), + website = VALUES(website), + unprofitable = VALUES(unprofitable), + voting_rights_difference = VALUES(voting_rights_difference), + agreement_control_structure = VALUES(agreement_control_structure), + exchange = VALUES(exchange) + """ + + # 准备参数列表 + params_list = [] + for _, row in self.df.iterrows(): + # 处理空值 + def get_value(col, default=None): + return row[col] if col in row and pd.notna(row[col]) else default + + # 处理数字字段 + def get_numeric(col, default=0): + value = get_value(col, default) + try: + return int(value) if value != '' and value is not None else default + except: + return default + + # 处理日期字段 + def get_date(col, default='1970-01-01'): + value = get_value(col, default) + if value in ['', None, 'NaT']: + return default + return value + + # 处理布尔字段 + def get_bool(col, default=False): + value = get_value(col, default) + if value in [True, '1', 1, 'Y', 'y', '是']: + return True + if value in [False, '0', 0, 'N', 'n', '否', '-', '']: + return False + return default + + params = ( + get_value('market_type'), # market_type + get_value('company_full_name', ''), # company_full_name + get_value('eng_name'), # eng_name + get_value('registered_address'), # registered_address + get_value('a_stock_code', ''), # a_stock_code + get_value('a_stock_short_name', ''), # a_stock_short_name + get_date('a_listing_date'), # a_listing_date + get_numeric('a_total_shares', 0), # a_total_shares + get_numeric('a_circulating_shares', 0), # a_circulating_shares + get_value('b_stock_code'), # b_stock_code + get_value('b_stock_short_name'), # b_stock_short_name + get_date('b_listing_date'), # b_listing_date + get_numeric('b_total_shares', 0), # b_total_shares + get_numeric('b_circulating_shares', 0), # b_circulating_shares + get_value('region'), # region + get_value('province'), # province + get_value('city'), # city + get_value('industry'), # industry + get_value('website'), # website + get_bool('unprofitable'), # unprofitable + get_bool('voting_rights_difference'), # voting_rights_difference + get_bool('agreement_control_structure'), # agreement_control_structure + get_value('exchange', '') # exchange + ) + + params_list.append(params) + + # 批量执行插入 + try: + total_rows = len(params_list) + if total_rows == 0: + logger.error("没有有效数据可插入") + return False + + batch_size = 500 # 每批插入500条记录(因为字段较多) + + logger.info(f"开始插入数据,共 {total_rows} 条记录") + + # 分批插入,避免大事务问题 + for i in range(0, total_rows, batch_size): + batch_params = params_list[i:i+batch_size] + affected_rows = db.execute_many(insert_sql, batch_params) + logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录") + + logger.info(f"成功插入/更新 {total_rows} 条记录") + return True + except Exception as e: + logger.error(f"插入数据失败: {e}") + # 记录前5个参数以帮助调试 + if params_list: + logger.debug(f"前5个参数示例: {params_list[:5]}") + return False + + def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool: + """验证数据库中的数据""" + try: + # 检查记录总数 + count_sql = "SELECT COUNT(*) AS total FROM stocks_sz" + result = db.execute_query(count_sql) + db_count = result[0]['total'] if result else 0 + logger.info(f"数据库中共有 {db_count} 条记录") + + # 随机抽样检查 + sample_sql = f""" + SELECT a_stock_code, a_stock_short_name, a_listing_date, province, city + FROM stocks_sz + ORDER BY RAND() + LIMIT {sample_size} + """ + samples = db.execute_query(sample_sql) + + logger.info("\n随机抽样记录:") + for idx, sample in enumerate(samples, 1): + logger.info(f"{idx}. {sample['a_stock_code']}: {sample['a_stock_short_name']} ({sample['a_listing_date']}) - {sample['province']}{sample['city']}") + + return True + except Exception as e: + logger.error(f"数据验证失败: {e}") + return False + + def run_import(self) -> bool: + """执行完整的导入流程""" + logger.info(f"开始导入股票数据,数据目录: {self.data_dir}") + start_time = datetime.now() + + # 1. 查找CSV文件 + csv_file = self.find_csv_file() + if not csv_file: + return False + + # 2. 验证文件 + if not self.validate_file(csv_file): + return False + + # 3. 读取CSV数据 + if not self.read_csv_data(csv_file): + return False + + # 4. 清洗数据 + if not self.clean_stock_data(): + return False + + # 显示前5条数据 + logger.info("\n前5条股票数据:") + for i, row in self.df.head().iterrows(): + logger.info(f"{row['a_stock_code']}: {row['a_stock_short_name']} ({row['a_listing_date']}) - {row['province']}{row['city']}") + + # 5. 连接数据库并导入 + try: + with MySQLHelper(**self.db_config) as db: + # 5.1 创建表 + if not self.create_stocks_table(db): + return False + + # 5.2 插入数据 + if not self.insert_data_to_db(db): + return False + + # 5.3 验证数据 + if not self.verify_data_in_db(db): + return False + except Exception as e: + logger.error(f"数据库操作异常: {e}") + return False + + # 计算执行时间 + duration = datetime.now() - start_time + logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}秒") + return True + +if __name__ == "__main__": + + # 数据库配置 + DB_CONFIG = { + 'host': 'localhost', + 'user': 'root', + 'password': 'bzskmysql', + 'database': 'fullmarketdata_a' + } + + # 获取当前脚本所在目录 + current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() + + # 设置数据目录 + DATA_DIR = current_dir / "data" + + # 确保data目录存在 + DATA_DIR.mkdir(exist_ok=True, parents=True) + + # 安装依赖 (如果chardet未安装) + try: + import chardet + except ImportError: + logger.info("安装chardet库以支持编码检测...") + import subprocess + subprocess.check_call([sys.executable, "-m", "pip", "install", "chardet"]) + import chardet + + # 创建导入器并执行导入 + importer = StockDataImporter(DATA_DIR, DB_CONFIG) + + if importer.run_import(): + logger.info("股票数据导入成功!") + else: + logger.error("股票数据导入失败,请检查日志了解详情") \ No newline at end of file diff --git a/getStockList.py b/getStockList.py new file mode 100644 index 0000000..2f4ec78 --- /dev/null +++ b/getStockList.py @@ -0,0 +1,5 @@ +import akshare as ak + +# 获取沪深 A 股股票列表 +stock_list = ak.stock_info_a_code_name() +print(stock_list) \ No newline at end of file