diff --git a/baostockDataScraping.py b/baostockDataScraping.py
new file mode 100644
index 0000000..4351ac9
--- /dev/null
+++ b/baostockDataScraping.py
@@ -0,0 +1,270 @@
+from MySQLHelper import MySQLHelper
+from LogHelper import LogHelper
+import logging
+import pandas as pd
+import re
+import time
+from datetime import datetime, timedelta
+from tqdm import tqdm
+import baostock as bs
+
+# Configure logging: DEBUG level, console handler plus a 'Debug.log' file handler.
+logHelper = LogHelper(
+    level=logging.DEBUG,
+    format='%(asctime)s [%(levelname)s] %(message)s'
+)
+
+logHelper.add_console_handler()
+logHelper.add_file_handler('Debug.log')
+logHelper.setup()
+logger = logging.getLogger('StockDataImporter')
+
+# Connection settings for the stock-list database (tables stocks_sh / stocks_sz are read from it).
+DB_CONFIG = {
+    'host': 'localhost',
+    'user': 'root',
+    'password': 'bzskmysql',  # NOTE(review): credential is hard-coded in source
+    'database': 'fullmarketdata_a',
+    'port': 3306,
+    'charset': 'utf8mb4'
+}
+
+DB_CONFIG_1D = {  # connection settings for the daily K-line database this script writes to
+    'host': 'localhost',
+    'user': 'root',
+    'password': 'bzskmysql',  # NOTE(review): credential is hard-coded in source
+    'database': 'klinedata_1d_ma_bao',
+    'port': 3306,
+    'charset': 'utf8mb4'
+}
+
+def get_SH_stock_codes_with_context() -> list:
+    """Return the non-empty a_stock_code values from table stocks_sh; [] if the query raises."""
+    with MySQLHelper(**DB_CONFIG) as db:  # presumably connects on enter and closes on exit - confirm in MySQLHelper
+        try:
+            results = db.execute_query("SELECT a_stock_code FROM stocks_sh")
+            return [row['a_stock_code'] for row in results if row['a_stock_code']]
+        except Exception as e:
+            logger.error(f"获取股票代码时出错: {e}")
+            return []
+
+def get_SZ_stock_codes_with_context() -> list:
+    """Return the non-empty a_stock_code values from table stocks_sz; [] if the query raises."""
+    with MySQLHelper(**DB_CONFIG) as db:  # presumably connects on enter and closes on exit - confirm in MySQLHelper
+        try:
+            results = db.execute_query("SELECT a_stock_code FROM stocks_sz")
+            return [row['a_stock_code'] for row in results if row['a_stock_code']]
+        except Exception as e:
+            logger.error(f"获取股票代码时出错: {e}")
+            return []
+
+def generate_table_name(code: str) -> str:
+    """Map a stock code (optionally 'sh.'/'sz.'-prefixed) to a table name such as sh_600000."""
+    # Strip a possible exchange prefix
+    clean_code = code.replace("sh.", "").replace("sz.", "")
+
+    if clean_code.startswith('6'):
+        return f"sh_{clean_code}"
+    elif clean_code.startswith(('0', '3')):
+        return f"sz_{clean_code}"
+    elif clean_code.startswith(('4', '8')):
+        return f"bj_{clean_code}"
+    return f"unknown_{clean_code}"  # leading digit matched none of the branches above
+
+def create_stock_table(db: MySQLHelper, table_name: str) -> bool:
+    """Create the daily K-line table (baostock column layout) if missing; return True on success, else False."""
+    if not re.match(r"^[a-z]{2}_[0-9]{6}$", table_name):  # the name is interpolated into the SQL below, so only xx_NNNNNN is accepted
+        logger.error(f"表名 '{table_name}' 不符合命名规则")
+        return False
+
+    create_table_sql = f"""
+    CREATE TABLE IF NOT EXISTS `{table_name}` (
+        `date` DATE NOT NULL COMMENT '日期',
+        `code` VARCHAR(10) NOT NULL COMMENT '代码',
+        `open` DECIMAL(10, 2) NOT NULL COMMENT '开盘价',
+        `high` DECIMAL(10, 2) NOT NULL COMMENT '最高价',
+        `low` DECIMAL(10, 2) NOT NULL COMMENT '最低价',
+        `close` DECIMAL(10, 2) NOT NULL COMMENT '收盘价',
+        `preclose` DECIMAL(10, 2) NOT NULL COMMENT '前收盘价',
+        `volume` BIGINT NOT NULL COMMENT '成交量(股)',
+        `amount` DECIMAL(20, 2) NOT NULL COMMENT '成交额(元)',
+        `adjustflag` TINYINT NOT NULL COMMENT '复权状态',
+        `turn` DECIMAL(10, 2) NOT NULL COMMENT '换手率(%)',
+        `tradestatus` TINYINT NOT NULL COMMENT '交易状态',
+        `pctChg` DECIMAL(10, 2) NOT NULL COMMENT '涨跌幅(%)',
+        `isST` TINYINT NOT NULL COMMENT '是否ST股',
+        PRIMARY KEY (`date`),
+        INDEX `idx_date` (`date`)
+    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票日K数据表';
+    """  # NOTE(review): idx_date indexes the same column as the primary key
+
+    try:
+        db.execute_update(create_table_sql)
+        logger.info(f"成功创建表: {table_name}")
+        return True
+    except Exception as e:
+        logger.error(f"创建表 {table_name} 失败: {e}")
+        return False
+
+def save_stock_data_to_db(db: MySQLHelper, df: pd.DataFrame, table_name: str) -> int:
+    """Upsert the traded rows of df into table_name; return db.execute_many's result, or 0 on no data, bad name or error."""
+    if df.empty:
+        logger.warning(f"表 {table_name}: 无数据可保存")
+        return 0
+
+    if not re.match(r"^[a-z]{2}_[0-9]{6}$", table_name):  # same whitelist as create_stock_table; the name goes into the SQL text
+        logger.error(f"表名 '{table_name}' 不符合命名规则")
+        return 0
+
+    # Build the upsert statement (column order matches the table layout)
+    insert_sql = f"""
+    INSERT INTO `{table_name}` (
+        date, code, open, high, low, close, preclose,
+        volume, amount, adjustflag, turn, tradestatus, pctChg, isST
+    ) VALUES (
+        %s, %s, %s, %s, %s, %s, %s,
+        %s, %s, %s, %s, %s, %s, %s
+    ) ON DUPLICATE KEY UPDATE
+        date = VALUES(date),
+        code = VALUES(code),
+        open = VALUES(open),
+        high = VALUES(high),
+        low = VALUES(low),
+        close = VALUES(close),
+        preclose = VALUES(preclose),
+        volume = VALUES(volume),
+        amount = VALUES(amount),
+        adjustflag = VALUES(adjustflag),
+        turn = VALUES(turn),
+        tradestatus = VALUES(tradestatus),
+        pctChg = VALUES(pctChg),
+        isST = VALUES(isST)
+    """
+
+    # Convert each DataFrame row into one parameter tuple
+    data_to_insert = []
+    for _, row in df.iterrows():
+
+        # Skip rows with no trading (tradestatus == '0')
+        if row['tradestatus'] == '0':
+            continue
+
+        # Coerce the fields to numeric types; a row that fails conversion is logged and left out
+        try:
+            data_to_insert.append((
+                row['date'],
+                row['code'],
+                float(row['open']),
+                float(row['high']),
+                float(row['low']),
+                float(row['close']),
+                float(row['preclose']),
+                int(row['volume']),  # note: baostock reports volume in shares (per original author)
+                float(row['amount']),
+                int(row['adjustflag']),
+                float(row['turn']),
+                int(row['tradestatus']),
+                float(row['pctChg']),
+                int(row['isST'])
+            ))
+        except Exception as e:
+            logger.error(f"处理行数据时出错: {e}\n行数据: {row}")
+
+    # Bulk insert
+    if not data_to_insert:
+        return 0
+
+    try:
+        affected_rows = db.execute_many(insert_sql, data_to_insert)
+        logger.info(f"表 {table_name}: 成功插入/更新 {affected_rows} 条记录")
+        return affected_rows
+    except Exception as e:
+        logger.error(f"保存数据到表 {table_name} 失败: {e}")
+        return 0
+
+if __name__ == "__main__":
+    # Log in to baostock
+    lg = bs.login()
+    logger.info(f'登陆返回: error_code={lg.error_code}, error_msg={lg.error_msg}')
+
+    if lg.error_code != '0':
+        logger.error("baostock登录失败,程序终止")
+        exit(1)
+
+    # Load the stock codes
+    logger.info("开始读取股票代码")
+    sh_codes = get_SH_stock_codes_with_context()
+    sz_codes = get_SZ_stock_codes_with_context()
+    logger.info(f"获取到上海股票数量: {len(sh_codes)}, 深圳股票数量: {len(sz_codes)}")
+
+    # Connect to the daily K-line database
+    db_1d = MySQLHelper(**DB_CONFIG_1D)
+    if not db_1d.connect():
+        logger.error("数据库连接失败")
+        exit(1)  # NOTE(review): exits without calling bs.logout()
+
+    # Disabled alternative, kept for reference: back-fill about five years instead of one day
+    #   end_date = datetime.now().strftime("%Y-%m-%d")
+    #   start_date = (datetime.now() - timedelta(days= 5 * 365)).strftime("%Y-%m-%d")
+    #   (followed by the same date-range log call as below)
+
+    # After each trading day closes the data is fetched once, so it is ready for later processing.
+    # Original note: range is "start today, end next day", no re-fetch. NOTE(review): the code below spans yesterday..today.
+    end_date = datetime.now().strftime("%Y-%m-%d")
+    start_date = (datetime.now() - timedelta(days= 1)).strftime("%Y-%m-%d")
+    logger.info(f"获取数据时间范围: {start_date} 至 {end_date}")
+
+    # Collect (exchange, code) pairs for every stock
+    all_codes = []
+    for code in sh_codes:
+        all_codes.append(("sh", code))
+    for code in sz_codes:
+        all_codes.append(("sz", code))
+
+    logger.info(f"总股票数量: {len(all_codes)}")
+
+    # Iterate with a tqdm progress bar
+    for exchange, code in tqdm(all_codes, desc="下载股票数据", unit="支"):
+        full_code = f"{exchange}.{code}"
+        table_name = generate_table_name(full_code)
+
+        # Create the table if it does not exist
+        if not create_stock_table(db_1d, table_name):
+            logger.error(f"跳过股票 {full_code}")
+            continue
+
+        # Fetch the daily bars
+        rs = bs.query_history_k_data_plus(
+            full_code,
+            "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST",
+            start_date=start_date,
+            end_date=end_date,
+            frequency="d",
+            adjustflag="2"  # forward-adjusted prices (per original comment)
+        )
+
+        if rs.error_code != '0':
+            logger.error(f"获取 {full_code} 数据失败: {rs.error_msg}")
+            continue
+
+        # Drain the result set into a list of rows
+        data_list = []
+        while rs.next():
+            data_list.append(rs.get_row_data())
+
+        if not data_list:
+            logger.warning(f"股票 {full_code} 无数据返回")
+            continue
+
+        df = pd.DataFrame(data_list, columns=rs.fields)
+
+        # Persist the rows
+        save_stock_data_to_db(db_1d, df, table_name)
+
+        # Throttle so requests are not sent too quickly
+        time.sleep(5)  # original note: consider lowering this delay
+
+    # Release the connections
+    bs.logout()
+    db_1d.close()
+    logger.info("程序执行完成")
\ No newline at end of file
diff --git a/webDataScraping.py b/webDataScraping.py
new file mode 100644
index 0000000..1866a13
--- /dev/null
+++ b/webDataScraping.py
@@ -0,0 +1,440 @@
+from MySQLHelper import MySQLHelper  # project-local MySQL helper class
+from LogHelper import LogHelper
+import logging
+import pandas as pd
+import akshare as ak
+import re
+import os  # NOTE(review): not referenced anywhere in this file
+import time
+from datetime import datetime, timedelta
+from tqdm import tqdm  # progress-bar helper; NOTE(review): not referenced anywhere in this file
+
+# Configure logging
+logHelper = LogHelper(
+    level=logging.DEBUG,  # log at DEBUG level
+    format='%(asctime)s [%(levelname)s] %(message)s'  # custom record format
+)
+
+# Attach handlers
+logHelper.add_console_handler()  # defaults to stdout (per original comment)
+logHelper.add_file_handler('Debug.log')  # also log to a file
+
+# Apply the configuration
+logHelper.setup()
+logger = logging.getLogger('StockDataImporter')
+
+# Connection settings: stock-list database
+DB_CONFIG = {
+    'host': 'localhost',
+    'user': 'root',
+    'password': 'bzskmysql',  # NOTE(review): credential is hard-coded in source
+    'database': 'fullmarketdata_a',
+    'port': 3306,
+    'charset': 'utf8mb4'
+}
+
+# Connection settings: daily K-line database
+DB_CONFIG_1D = {
+    'host': 'localhost',
+    'user': 'root',
+    'password': 'bzskmysql',  # NOTE(review): credential is hard-coded in source
+    'database': 'klinedata_1d_ma',
+    'port': 3306,
+    'charset': 'utf8mb4'
+}
+
+# Approach 1: connect and close explicitly. NOTE(review): the two functions below are not called in this file.
+def get_SH_stock_codes() -> list:
+    """
+    Fetch every non-empty a_stock_code value from table stocks_sh.
+
+    Returns:
+        list: the stock codes; [] if connecting or querying fails
+    """
+    # Create the database helper
+    db = MySQLHelper(**DB_CONFIG)
+
+    try:
+        # Connect
+        if not db.connect():
+            logger.error("数据库连接失败")
+            return []
+
+        # Run the query
+        results = db.execute_query("SELECT a_stock_code FROM stocks_sh")
+
+        # Extract the codes, dropping empty values
+        stock_codes = [row['a_stock_code'] for row in results if row['a_stock_code']]
+
+        return stock_codes
+
+    except Exception as e:
+        logger.error(f"获取股票代码时出错: {e}")
+        return []
+
+    finally:
+        # Always close the connection, including on the early returns above
+        db.close()
+
+def get_SZ_stock_codes() -> list:
+    """
+    Fetch every non-empty a_stock_code value from table stocks_sz.
+
+    Returns:
+        list: the stock codes; [] if connecting or querying fails
+    """
+    # Create the database helper
+    db = MySQLHelper(**DB_CONFIG)
+
+    try:
+        # Connect
+        if not db.connect():
+            logger.error("数据库连接失败")
+            return []
+
+        # Run the query
+        results = db.execute_query("SELECT a_stock_code FROM stocks_sz")
+
+        # Extract the codes, dropping empty values
+        stock_codes = [row['a_stock_code'] for row in results if row['a_stock_code']]
+
+        return stock_codes
+
+    except Exception as e:
+        logger.error(f"获取股票代码时出错: {e}")
+        return []
+
+    finally:
+        # Always close the connection, including on the early returns above
+        db.close()
+
+# Approach 2: use the context manager (recommended by the original author)
+def get_SH_stock_codes_with_context() -> list:
+    """
+    Fetch every non-empty a_stock_code value from table stocks_sh via the context manager.
+
+    Returns:
+        list: the stock codes; [] if the query raises
+    """
+    # The context manager handles the connection (per original comment)
+    with MySQLHelper(**DB_CONFIG) as db:
+        try:
+            # Run the query
+            results = db.execute_query("SELECT a_stock_code FROM stocks_sh")
+
+            # Extract the codes, dropping empty values
+            return [row['a_stock_code'] for row in results if row['a_stock_code']]
+
+        except Exception as e:
+            logger.error(f"获取股票代码时出错: {e}")
+            return []
+
+def get_SZ_stock_codes_with_context() -> list:
+    """
+    Fetch every non-empty a_stock_code value from table stocks_sz via the context manager.
+
+    Returns:
+        list: the stock codes; [] if the query raises
+    """
+    # The context manager handles the connection (per original comment)
+    with MySQLHelper(**DB_CONFIG) as db:
+        try:
+            # Run the query
+            results = db.execute_query("SELECT a_stock_code FROM stocks_sz")
+
+            # Extract the codes, dropping empty values
+            return [row['a_stock_code'] for row in results if row['a_stock_code']]
+
+        except Exception as e:
+            logger.error(f"获取股票代码时出错: {e}")
+            return []
+
+def get_daily_k_data(stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
+    """
+    Fetch daily K-line data for one stock. NOTE(review): shadowed by the second definition of this name below.
+
+    Args:
+        stock_code: stock code passed to akshare as `symbol` (original example: sh600000)
+        start_date: start date (YYYYMMDD)
+        end_date: end date (YYYYMMDD)
+
+    Returns:
+        DataFrame: the daily bars with renamed columns, or an empty DataFrame on error
+    """
+    try:
+        # Fetch the historical quotes
+        df = ak.stock_zh_a_hist(
+            symbol=stock_code,
+            period="daily",
+            start_date=start_date,
+            end_date=end_date,
+            adjust="qfq"  # forward-adjusted (per original comment)
+        )
+
+        # If nothing came back, retry with the "raw" code. NOTE(review): the retry repeats the identical call.
+        if df.empty and not stock_code.startswith(('sh', 'sz', 'bj')):
+            logger.info(f"尝试使用原始代码: {stock_code}")
+            df = ak.stock_zh_a_hist(
+                symbol=stock_code,
+                period="daily",
+                start_date=start_date,
+                end_date=end_date,
+                adjust="qfq"
+            )
+
+        # Rename the columns positionally (11 names; assumes akshare returns 11 columns - TODO confirm)
+        if not df.empty:
+            df.columns = [
+                'date', 'open', 'close', 'high', 'low',
+                'volume', 'amount', 'amplitude', 'change_percent',
+                'change_amount', 'turnover'
+            ]
+            df['code'] = stock_code  # add the stock-code column
+
+        return df
+
+    except Exception as e:
+        logger.error(f"获取 {stock_code} 日K数据时出错: {e}")
+        return pd.DataFrame()
+
+def format_stock_code(code: str) -> str:
+    """
+    Prefix a stock code with its exchange (for akshare, per original author). NOTE(review): not called in this file.
+
+    Rules:
+    - starts with 6: Shanghai Stock Exchange (sh)
+    - starts with 0 or 3: Shenzhen Stock Exchange (sz)
+    - starts with 4 or 8: Beijing Stock Exchange (bj)
+
+    Returns: exchange prefix + stock code
+    """
+    # Already prefixed: return unchanged
+    if code.startswith(('sh', 'sz', 'bj')):
+        return code
+
+    # Choose the exchange from the leading digit
+    if code.startswith('6'):
+        return f"sh{code}"
+    elif code.startswith(('0', '3')):
+        return f"sz{code}"
+    elif code.startswith(('4', '8')):
+        return f"bj{code}"
+    else:
+        logger.error(f"无法识别的股票代码格式: {code}")
+        return code  # return it unchanged and let akshare try to handle it
+
+def get_daily_k_data(stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
+    """
+    Fetch daily K-line data for one stock (this later definition replaces the earlier one of the same name).
+
+    Args:
+        stock_code: stock code passed to akshare as `symbol` (original example: sh600000)
+        start_date: start date (YYYYMMDD)
+        end_date: end date (YYYYMMDD)
+
+    Returns:
+        DataFrame: the daily bars with renamed columns, or an empty DataFrame on error
+    """
+    try:
+        # Fetch the historical quotes
+        df = ak.stock_zh_a_hist(
+            symbol=stock_code,
+            period="daily",
+            start_date=start_date,
+            end_date=end_date,
+            adjust="qfq"  # forward-adjusted (per original comment)
+        )
+
+        # Rename the columns positionally (12 names; assumes akshare returns 12 columns - TODO confirm)
+        if not df.empty:
+            df.columns = [
+                'date', 'code', 'open', 'close', 'high', 'low',
+                'volume', 'amount', 'amplitude', 'change_percent',
+                'change_amount', 'turnover'
+            ]
+        return df
+
+    except Exception as e:
+        logger.error(f"获取 {stock_code} 日K数据时出错: {e}")
+        return pd.DataFrame()
+
+def create_stock_table(db: MySQLHelper, table_name: str) -> bool:
+    """
+    Create the per-stock daily K-line table if it does not exist.
+
+    Args:
+        db: database helper used to run the DDL
+        table_name: table name (exchange_code, e.g. sh_600000)
+
+    Returns:
+        bool: True on success, False for an invalid name or a failed statement
+    """
+    # Validate the name; it is interpolated into the SQL below
+    if not re.match(r"^[a-z]{2}_[0-9]{6}$", table_name):
+        print(f"表名 '{table_name}' 不符合命名规则")
+        return False
+
+    # CREATE TABLE statement
+    create_table_sql = f"""
+    CREATE TABLE IF NOT EXISTS `{table_name}` (
+        `date` DATE NOT NULL COMMENT '日期',
+        `code` DECIMAL(10, 2) NOT NULL COMMENT '代码',
+        `open` DECIMAL(10, 2) NOT NULL COMMENT '开盘价',
+        `close` DECIMAL(10, 2) NOT NULL COMMENT '收盘价',
+        `high` DECIMAL(10, 2) NOT NULL COMMENT '最高价',
+        `low` DECIMAL(10, 2) NOT NULL COMMENT '最低价',
+        `volume` BIGINT NOT NULL COMMENT '成交量(手)',
+        `amount` DECIMAL(20, 2) NOT NULL COMMENT '成交额(元)',
+        `amplitude` DECIMAL(5, 2) NOT NULL COMMENT '振幅(%)',
+        `change_percent` DECIMAL(5, 2) NOT NULL COMMENT '涨跌幅(%)',
+        `change_amount` DECIMAL(5, 2) NOT NULL COMMENT '涨跌额(元)',
+        `turnover` DECIMAL(5, 2) NOT NULL COMMENT '换手率(%)',
+        PRIMARY KEY (`date`),
+        INDEX `idx_date` (`date`)
+    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票日K数据表';
+    """  # NOTE(review): `code` is declared DECIMAL(10, 2) - confirm that is intended for a stock code
+
+    try:
+        db.execute_update(create_table_sql)
+        return True
+    except Exception as e:
+        print(f"创建表 {table_name} 失败: {e}")
+        return False
+
+def save_stock_data_to_db(db: MySQLHelper, df: pd.DataFrame, table_name: str) -> int:
+    """
+    Upsert the rows of df into the given table.
+
+    Args:
+        db: database helper
+        df: DataFrame holding the stock data
+        table_name: table name (exchange_code, e.g. sh_600000)
+
+    Returns:
+        int: value returned by db.execute_many, or 0 on empty input, invalid name or error
+    """
+    if df.empty:
+        return 0
+
+    # Validate the name; it is interpolated into the SQL below
+    if not re.match(r"^[a-z]{2}_[0-9]{6}$", table_name):
+        print(f"表名 '{table_name}' 不符合命名规则")
+        return 0
+
+    # Build the upsert statement
+    insert_sql = f"""
+    INSERT INTO `{table_name}` (
+        date, code, open, close, high, low,
+        volume, amount, amplitude, change_percent,
+        change_amount, turnover
+    ) VALUES (
+        %s, %s, %s, %s, %s,
+        %s, %s, %s, %s,
+        %s, %s, %s
+    ) ON DUPLICATE KEY UPDATE
+        code = VALUES(code),
+        open = VALUES(open),
+        close = VALUES(close),
+        high = VALUES(high),
+        low = VALUES(low),
+        volume = VALUES(volume),
+        amount = VALUES(amount),
+        amplitude = VALUES(amplitude),
+        change_percent = VALUES(change_percent),
+        change_amount = VALUES(change_amount),
+        turnover = VALUES(turnover)
+    """
+
+    # Convert each DataFrame row into one parameter tuple
+    data_to_insert = []
+    for _, row in df.iterrows():
+        # The date is passed through as-is; the normalisation below is disabled
+        date_value = row['date']
+        # if len(date_str) == 10:  # YYYY-MM-DD
+        #     date_value = date_str
+        # else:
+        #     try:
+        #         date_value = datetime.strptime(date_str, '%Y-%m-%d').strftime('%Y-%m-%d')
+        #     except:
+        #         # try other date formats
+        #         date_value = date_str[:10]  # keep the first 10 characters
+
+        data_to_insert.append((
+            date_value, row['code'], row['open'], row['close'],
+            row['high'], row['low'], row['volume'], row['amount'],
+            row['amplitude'], row['change_percent'],
+            row['change_amount'], row['turnover']
+        ))
+
+    # Bulk insert
+    if data_to_insert:
+        try:
+            affected_rows = db.execute_many(insert_sql, data_to_insert)
+            print(f"表 {table_name}: 成功插入/更新 {affected_rows} 条记录")
+            return affected_rows
+        except Exception as e:
+            print(f"保存数据到表 {table_name} 失败: {e}")
+            return 0
+    return 0
+
+def generate_table_name(stock_code: str) -> str:
+    """
+    Build the table name (exchange_code) from the leading digit of the stock code.
+
+    Args:
+        stock_code: stock code; NOTE(review): a code that already carries an exchange prefix falls through to "unknown_"
+
+    Returns:
+        str: table name (e.g. sh_600000)
+    """
+    if stock_code.startswith('6'):
+        return f"sh_{stock_code}"
+    elif stock_code.startswith(('0', '3')):
+        return f"sz_{stock_code}"
+    elif stock_code.startswith(('4', '8')):
+        return f"bj_{stock_code}"
+
+    # Fallback for any other leading character
+    return f"unknown_{stock_code}"
+
+if __name__ == "__main__":
+
+    # Load the stock codes
+    logger.info("读取股票代码")
+    sh_stock_codes_context = get_SH_stock_codes_with_context()
+    sz_stock_codes_context = get_SZ_stock_codes_with_context()
+    all_stock_codes = sh_stock_codes_context + sz_stock_codes_context
+
+    if all_stock_codes:
+        logger.info(f"前五个代码:{all_stock_codes[:5]}")
+        logger.info(f"后五个代码:{all_stock_codes[-6:-1]}")  # NOTE(review): [-6:-1] omits the final element; [-5:] would give the last five
+
+    # Database that stores the daily K-line data
+    db_1d = MySQLHelper(**DB_CONFIG_1D)
+    if not db_1d.connect():
+        logger.error("数据库连接失败")  # NOTE(review): execution continues after a failed connect
+
+    # Roughly the last 3 years; the end date is set to tomorrow
+    start_date = (datetime.now() - timedelta(days = 3 * 365)).strftime("%Y%m%d")
+    end_date = (datetime.now() + timedelta(days = 1)).strftime("%Y%m%d")
+    logger.info(f"获取数据时间范围: {start_date} 至 {end_date}")
+
+    nCount = 0
+    for code in all_stock_codes:
+        nCount = nCount+1
+        if nCount < 1584:  # skips the first 1583 codes; presumably a manual resume point - TODO confirm
+            continue
+        df = get_daily_k_data(code,start_date,end_date)
+
+        # Table name (exchange_code)
+        table_name = generate_table_name(code)
+
+        # Create the table if it does not exist
+        if not create_stock_table(db_1d, table_name):
+            logger.error(f"无法为股票 {code} 创建表 {table_name}")  # NOTE(review): no continue here, so the save below still runs
+
+        # Save the data into the table
+        save_stock_data_to_db(db_1d, df, table_name)
+
+        # Throttle so requests are not sent too quickly; NOTE(review): db_1d is never closed in this script
+        time.sleep(5)
\ No newline at end of file