update exportdata

This commit is contained in:
2025-09-03 21:41:11 +08:00
parent 41beb2ec33
commit 7d35766a09
13 changed files with 734 additions and 91 deletions

View File

@@ -143,10 +143,54 @@ class DataExporter:
self.logger.error("无法获取月度均价数据")
return False
# 读取港股通标记
hk_inout_data = self.get_hk_inout()
for item in monthly_data:
stock_name = item.get('stock_code')[3:]
if stock_name in hk_inout_data:
item['in_out'] = 1
else:
item['in_out'] = 0
# 导出结果
file_path = 'data/' + csv_file if csv_file else None
csv_success = True
if csv_file:
csv_success = self.export_to_csv(monthly_data, file_path)
return csv_success
return csv_success
def get_hk_inout(self) -> Optional[List[str]]:
    """
    Read the stock codes currently flagged as HK Stock Connect members.

    Queries the hk_stock_connect table for every row whose in_out flag
    is '1' (i.e. currently in the Stock Connect list).

    Returns:
        Optional[List[str]]: sorted list of stock codes, or None on
        failure or when the query returns no rows.
    """
    # Bug fixes vs. original: the docstring wrongly described this as a
    # float-share query with a table_name argument, the return annotation
    # said List[Dict] although a list of code strings is returned, and a
    # placeholder-less f-string was used for the error message.
    try:
        with MySQLHelper(**self.db_config) as db:
            # Only rows flagged in_out = '1' are current members.
            data = db.execute_query("""
                SELECT stock_code
                FROM hk_stock_connect
                WHERE in_out = '1'
                ORDER BY stock_code
            """)
            if not data:
                self.logger.error("获取数据失败")
                return None
            # Flatten row dicts into a plain list, skipping NULL/empty codes.
            return [
                row['stock_code']
                for row in data
                if row['stock_code']
            ]
    except Exception as e:
        self.logger.error(f"从数据库读取港股通标记数据失败: {str(e)}")
        return None

View File

@@ -16,7 +16,8 @@ from typing import Optional, List, Dict, Union, Tuple
import csv
from typing import List, Dict, Optional
from datetime import datetime
from base import LogHelper, MySQLHelper, Config
from base import LogHelper, MySQLHelper, ConfigInfo
class MarketDataCalculator:
"""
@@ -41,8 +42,8 @@ class MarketDataCalculator:
"""
self.db_config = db_config
self.logger = LogHelper(logger_name=logger_name).setup()
self.month_ranges = Config.ConfigInfo.MONTH_RANGES
self.head_map = Config.ConfigInfo.HEADER_MAP
self.month_ranges = ConfigInfo.MONTH_RANGES
self.head_map = ConfigInfo.HEADER_MAP
def create_monthly_avg_table(self, target_table: str = "monthly_close_avg") -> bool:
"""
@@ -69,6 +70,7 @@ class MarketDataCalculator:
ym_2506 DECIMAL(20, 5) COMMENT '2025年06月',
ym_2507 DECIMAL(20, 5) COMMENT '2025年07月',
ym_2508 DECIMAL(20, 5) COMMENT '2025年08月',
ym_2509 DECIMAL(20, 5) COMMENT '2025年09月',
avg_all DECIMAL(20, 5) COMMENT '月间均值',
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
UNIQUE KEY uk_stock_code (stock_code)
@@ -82,7 +84,7 @@ class MarketDataCalculator:
return False
def calculate_and_save_monthly_avg(self,
source_table: str = "stock_quotes",
stock_code: str = "code",
target_table: str = "monthly_close_avg") -> bool:
"""
计算并保存2024年10月至2025年8月的月均流通市值
@@ -100,79 +102,85 @@ class MarketDataCalculator:
return False
with MySQLHelper(**self.db_config) as db:
# 获取所有股票代码和名称
stock_info = db.execute_query(
f"SELECT DISTINCT stock_code, stock_name FROM {source_table}"
)
if not stock_info:
self.logger.error("没有获取到股票基本信息")
return False
# 为每只股票计算各月均值
for stock in stock_info:
stock_code = stock['stock_code']
stock_name = stock['stock_name']
monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}
# 计算每个月的均值
for month_col, (start_date, end_date) in self.month_ranges.items():
sql = """
SELECT AVG(close_price * float_share) as avg_close
FROM {}
WHERE stock_code = %s
AND trade_date BETWEEN %s AND %s
AND close_price IS NOT NULL
AND float_share IS NOT NULL
""".format(source_table)
result = db.execute_query(sql, (stock_code, start_date, end_date))
# 保存小数点后两位,以亿为单位
# monthly_data[month_col] = float(result[0]['avg_close']) * 1000 if result and result[0]['avg_close'] else None
monthly_data[month_col] = round(float(result[0]['avg_close']) * 1000 / 100000000, 3) if result and result[0]['avg_close'] else None
# 提取所有以 'ym_' 开头的键的值
ym_values = [value for key, value in monthly_data.items() if key.startswith('ym_')]
valid_ym_values = [value for value in ym_values if value is not None]
# 计算全部月的均值
if valid_ym_values:
average = sum(valid_ym_values) / len(valid_ym_values)
monthly_data['avg_all'] = average
self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
else:
monthly_data['avg_all'] = 0 # 给一个空值,保证数据库不报错
self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")
# 插入或更新数据
upsert_sql = f"""
INSERT INTO {target_table} (
stock_code, stock_name,
ym_2501, ym_2502, ym_2503, ym_2504,
ym_2505, ym_2506,ym_2507, ym_2508,
avg_all
) VALUES (
%(stock_code)s, %(stock_name)s,
%(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s,
%(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s,
%(avg_all)s
)
ON DUPLICATE KEY UPDATE
stock_name = VALUES(stock_name),
ym_2501 = VALUES(ym_2501),
ym_2502 = VALUES(ym_2502),
ym_2503 = VALUES(ym_2503),
ym_2504 = VALUES(ym_2504),
ym_2505 = VALUES(ym_2505),
ym_2506 = VALUES(ym_2506),
ym_2507 = VALUES(ym_2507),
ym_2508 = VALUES(ym_2508),
avg_all = VALUES(avg_all),
update_time = CURRENT_TIMESTAMP
sql = """
SELECT stock_code, stock_name
FROM stock_filter
WHERE stock_code = %s
"""
db.execute_update(upsert_sql, monthly_data)
# stock_filter表格可以当成标准表
stock_info = db.execute_query(sql,(stock_code))
if len(stock_info) == 0:
return
stock_code = stock_info[0]['stock_code']
stock_name = stock_info[0]['stock_name']
monthly_data = {'stock_code': stock_code, 'stock_name': stock_name}
# 计算每个月的均值
source_table = 'hk_' + stock_code[3:]
for month_col, (start_date, end_date) in self.month_ranges.items():
sql = """
SELECT AVG(close_price * float_share) as avg_close
FROM {}
WHERE stock_code = %s
AND trade_date BETWEEN %s AND %s
AND close_price IS NOT NULL
AND float_share IS NOT NULL
""".format(source_table)
result = db.execute_query(sql, (stock_code, start_date, end_date))
# 保存小数点后两位,以亿为单位
# monthly_data[month_col] = float(result[0]['avg_close']) * 1000 if result and result[0]['avg_close'] else None
monthly_data[month_col] = round(float(result[0]['avg_close']) * 1000 / 100000000, 3) if result and result[0]['avg_close'] else None
# 提取所有以 'ym_' 开头的键的值
ym_values = [value for key, value in monthly_data.items() if key.startswith('ym_')]
valid_ym_values = [value for value in ym_values if value is not None]
# 计算全部月的均值
if valid_ym_values:
average = sum(valid_ym_values) / len(valid_ym_values)
monthly_data['avg_all'] = average
self.logger.debug(f"股票 {stock_code} 月间流通市值平均值为: {average}")
else:
monthly_data['avg_all'] = 0 # 给一个空值,保证数据库不报错
self.logger.warning(f"股票 {stock_code} 没有有效的月度数据")
# 插入或更新数据
upsert_sql = f"""
INSERT INTO {target_table} (
stock_code, stock_name,
ym_2501, ym_2502, ym_2503, ym_2504,
ym_2505, ym_2506,ym_2507, ym_2508,
ym_2509,
avg_all
) VALUES (
%(stock_code)s, %(stock_name)s,
%(ym_2501)s, %(ym_2502)s, %(ym_2503)s, %(ym_2504)s,
%(ym_2505)s, %(ym_2506)s, %(ym_2507)s, %(ym_2508)s,
%(ym_2509)s,
%(avg_all)s
)
ON DUPLICATE KEY UPDATE
stock_name = VALUES(stock_name),
ym_2501 = VALUES(ym_2501),
ym_2502 = VALUES(ym_2502),
ym_2503 = VALUES(ym_2503),
ym_2504 = VALUES(ym_2504),
ym_2505 = VALUES(ym_2505),
ym_2506 = VALUES(ym_2506),
ym_2507 = VALUES(ym_2507),
ym_2508 = VALUES(ym_2508),
ym_2509 = VALUES(ym_2509),
avg_all = VALUES(avg_all),
update_time = CURRENT_TIMESTAMP
"""
db.execute_update(upsert_sql, monthly_data)
# self.logger.info("月度均值计算和保存完成")
return True
except Exception as e:
@@ -435,4 +443,6 @@ class MarketDataCalculator:
if csv_file:
csv_success = self.export_to_csv(monthly_data, file_path)
return csv_success
return csv_success

View File

@@ -193,7 +193,7 @@ class KLineUpdater:
# 预处理数据
processed_data = self.preprocess_quote_data(quote_data, float_share)
if not processed_data:
self.logger.error("没有有效数据需要保存")
self.logger.error(f"没有有效数据需要保存,表:{table_name}")
return False
# 动态生成SQL插入语句
@@ -252,7 +252,7 @@ class KLineUpdater:
self.logger.info(f"创建了新表: {table_name}")
affected_rows = db.execute_many(insert_sql, processed_data)
self.logger.info(f"成功插入/更新 {affected_rows} 条行情记录到表 {table_name}")
# self.logger.info(f"成功插入/更新 {affected_rows} 条行情记录到表 {table_name}")
return True
except Exception as e:
self.logger.error(f"保存行情数据到表 {table_name} 失败: {str(e)}")

View File

@@ -23,7 +23,9 @@ class ConfigInfo:
'ym_2506': '2025年06月',
'ym_2507': '2025年07月',
'ym_2508': '2025年08月',
'avg_all': '度均值'
'ym_2509': '2025年09',
'avg_all': '月度均值',
'in_out':'是否在港股通'
}
# 月份范围配置
@@ -35,6 +37,7 @@ class ConfigInfo:
'ym_2505': ('2025-05-01', '2025-05-31'),
'ym_2506': ('2025-06-01', '2025-06-30'),
'ym_2507': ('2025-07-01', '2025-07-31'),
'ym_2508': ('2025-08-01', '2025-08-31')
'ym_2508': ('2025-08-01', '2025-08-31'),
'ym_2509': ('2025-09-01', '2025-09-30')
}

View File

@@ -9,7 +9,7 @@ import pymysql
from pymysql import Error
from typing import List, Dict, Union, Optional, Tuple
from contextlib import contextmanager
from base.LogHelper import LogHelper
from .LogHelper import LogHelper
# 基本用法(自动创建日期日志+控制台输出)
logger = LogHelper(logger_name = 'database').setup()
@@ -84,7 +84,9 @@ class MySQLHelper:
"""
try:
self.cursor.execute(sql, params)
return self.cursor.fetchall()
result = self.cursor.fetchall()
return result
except Error as e:
logger.error(f"查询执行失败: {e}")
return []
@@ -233,4 +235,4 @@ class MySQLHelper:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
self.close()

499
base/StockDataImporter.py Normal file
View File

@@ -0,0 +1,499 @@
"""
存储上海证券交易股票列表数据
不确定其数据爬取规则,防止 IP 被封
暂时使用该方案,获取股票列表数据
—— 下载excel手动导入到数据库
"""
import pandas as pd
import os
import sys
import csv
import chardet # 用于检测文件编码
from pathlib import Path
from datetime import datetime
from .MySQLHelper import MySQLHelper
from .LogHelper import LogHelper
logger = LogHelper(logger_name = 'execelImport').setup()
class StockDataImporter:
"""股票数据导入工具支持CSV"""
def __init__(self, db_config: dict, column_mapping: dict, data_dir: Path):
self.db_config = db_config
self.column_mapping = column_mapping
self.data_dir = data_dir
self.df = None
self.csv_file = None
self.encoding = 'utf-8' # 默认编码
self.delimiter = ',' # 默认分隔符
self.upload_filename = None # 上传文件名
# 更新港股通进出标志
def setUploadfile(self, filename: str):
    """Record the CSV filename that find_csv_file() should look for."""
    self.upload_filename = filename
    # Bug fix: the log message previously contained a literal placeholder
    # instead of interpolating the actual filename.
    logger.info(f"设置上传文件名为: {filename}")
def find_csv_file(self) -> Path:
    """Locate the CSV file to import inside the data directory.

    Uses the name set via setUploadfile(), falling back to "GPLIST.csv".
    When several files match the pattern, the most recently modified
    one wins.

    Returns:
        Path: the chosen file, or None when nothing matches.
    """
    filename = self.upload_filename if self.upload_filename else "GPLIST.csv"
    csv_files = list(self.data_dir.glob(filename))
    if not csv_files:
        # Bug fix: log the actual filename instead of a literal placeholder.
        logger.error(f"{self.data_dir} 中没有找到文件: {filename}")
        return None
    # Several matches: prefer the newest by modification time.
    if len(csv_files) > 1:
        csv_files.sort(key=os.path.getmtime, reverse=True)
        logger.info(f"找到多个文件,选择最新的: {csv_files[0].name}")
    return csv_files[0]
def validate_file(self, file_path: Path) -> bool:
    """Check that the CSV file exists and is non-empty."""
    try:
        if not file_path.exists():
            logger.error(f"CSV文件不存在: {file_path}")
            return False
        if file_path.stat().st_size == 0:
            logger.error(f"CSV文件为空: {file_path}")
            return False
    except Exception as e:
        # stat()/exists() can raise on permission problems etc.
        logger.error(f"文件验证失败: {e}")
        return False
    return True
def detect_file_encoding(self, file_path: Path) -> str:
    """Sniff the file's text encoding from its first 10KB.

    Maps a few detected encodings onto more practical supersets
    (e.g. GB2312 is read as GBK) and falls back to UTF-8 when
    detection fails or yields nothing.
    """
    try:
        # Sample only the head of the file for speed.
        with open(file_path, 'rb') as fh:
            sample = fh.read(10000)
        detection = chardet.detect(sample)
        detected = detection['encoding']
        confidence = detection['confidence']
        # Substitute aliases / supersets for some detections.
        replacements = {
            'GB2312': 'GBK',
            'gb2312': 'GBK',
            'ISO-8859-1': 'latin1',
            'ascii': 'utf-8'
        }
        detected = replacements.get(detected, detected)
        logger.info(f"检测到编码: {detected} (置信度: {confidence:.2f})")
        return detected or 'utf-8'
    except Exception as e:
        logger.error(f"编码检测失败: {e}, 使用默认UTF-8")
        return 'utf-8'
def detect_csv_delimiter(self, file_path: Path) -> str:
    """Guess the CSV delimiter by counting candidates in the first lines.

    Counts ',', tab, ';' and '|' over (up to) the first five lines and
    picks the most frequent one; defaults to ',' when none appears or
    on any error.
    """
    try:
        with open(file_path, 'r', encoding=self.encoding) as f:
            # Bug fix: the original called f.readline() twice per
            # iteration ("for _ in range(5) if f.readline()"), silently
            # discarding every other line of the sample.
            lines = []
            for _ in range(5):
                line = f.readline()
                if not line:
                    break
                lines.append(line)
        delimiters = [',', '\t', ';', '|']
        delimiter_counts = {
            delim: sum(line.count(delim) for line in lines)
            for delim in delimiters
        }
        # Candidate with the highest total count wins.
        best_delim = max(delimiter_counts, key=delimiter_counts.get)
        if delimiter_counts[best_delim] == 0:
            logger.warning(f"无法检测到有效的分隔符,使用默认逗号分隔符")
            return ','
        logger.info(f"检测到分隔符: {repr(best_delim)}")
        return best_delim
    except Exception as e:
        logger.error(f"检测分隔符失败: {e}, 使用默认逗号分隔符")
        return ','
def read_csv_data(self, file_path: Path) -> bool:
    """Load the CSV into self.df, auto-detecting encoding and delimiter.

    On a UnicodeDecodeError the read is retried with a list of common
    fallback encodings. Columns are renamed via self.column_mapping and
    fully-empty rows are dropped.

    Returns:
        bool: True when a non-empty DataFrame was loaded, else False.
    """
    try:
        # 1. Detect encoding, 2. detect delimiter, then 3. parse.
        self.encoding = self.detect_file_encoding(file_path)
        self.delimiter = self.detect_csv_delimiter(file_path)
        logger.info(f"使用编码 '{self.encoding}' 和分隔符 '{self.delimiter}' 读取文件")
        self.df = pd.read_csv(
            file_path,
            delimiter=self.delimiter,
            dtype=str,
            encoding=self.encoding,
            on_bad_lines='warn',
            quoting=csv.QUOTE_MINIMAL,
            engine='python'  # more robust parser
        )
        if self.df.empty:
            logger.error("CSV文件没有包含有效数据")
            return False
        # Map raw Chinese headers to the canonical column names.
        self.df = self.df.rename(columns=self.column_mapping)
        # Drop rows where every cell is missing.
        self.df = self.df.dropna(how='all')
        logger.info(f"成功读取CSV数据{len(self.df)} 条记录")
        return True
    except UnicodeDecodeError:
        # Retry with common fallback encodings.
        for enc in ['GBK', 'latin1', 'ISO-8859-1', 'utf-16']:
            try:
                logger.warning(f"尝试使用 {enc} 编码读取文件")
                self.df = pd.read_csv(
                    file_path,
                    delimiter=self.delimiter,
                    dtype=str,
                    encoding=enc
                )
                self.encoding = enc
                # Bug fix: the fallback path previously skipped the
                # column rename / empty-row cleanup done on the main path.
                self.df = self.df.rename(columns=self.column_mapping)
                self.df = self.df.dropna(how='all')
                logger.info(f"成功使用 {enc} 编码读取文件")
                return True
            except Exception:  # bug fix: was a bare "except:"
                continue
        logger.error("所有编码尝试均失败")
        return False
    except PermissionError:
        logger.error(f"文件被占用,请关闭后重试: {file_path}")
        return False
    except Exception as e:
        logger.error(f"读取CSV文件失败: {e}")
        return False
def clean_stock_data(self) -> bool:
    """Light validation pass: flag stock codes that are not six digits.

    Invalid codes are only logged, not removed — presumably they are
    handled downstream or tolerated; verify against callers.
    """
    try:
        if 'stock_code' in self.df.columns:
            codes = self.df['stock_code'].astype(str)
            bad_rows = self.df[~codes.str.match(r'^\d{6}$')]
            if not bad_rows.empty:
                logger.warning(f"发现 {len(bad_rows)} 条无效的股票代码")
                logger.debug(f"无效代码示例: {bad_rows['stock_code'].head().tolist()}")
        logger.info("数据清洗完成")
        return True
    except Exception as e:
        logger.error(f"数据清洗失败: {e}")
        return False
def create_stocks_table(self, db: MySQLHelper) -> bool:
"""Create the hk_stock_connect table if it does not exist.

Columns come from self.column_mapping plus an in_out membership
flag and created_at/updated_at timestamps.

Returns:
    bool: True when the CREATE TABLE succeeded, else False.
"""
# Known columns get specific types; anything else falls back below.
column_type_mapping = {
'stock_code': 'VARCHAR(6) PRIMARY KEY',
'stock_name_chn': 'VARCHAR(50) NULL',
'stock_name_en': 'VARCHAR(150)',
}
# Build one "name TYPE" fragment per mapped column.
column_definitions = []
for column_name in self.column_mapping.values():
if column_name in column_type_mapping:
column_definitions.append(f"{column_name} {column_type_mapping[column_name]}")
else:
# Unknown columns default to VARCHAR(255).
column_definitions.append(f"{column_name} VARCHAR(255)")
# Membership flag (set to 1 by the importer for current members).
column_definitions.append("in_out TINYINT(1) DEFAULT 0 COMMENT '进出标志'")
# Audit timestamps.
column_definitions.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'")
column_definitions.append("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间'")
# Assemble the full CREATE TABLE statement.
columns_sql = ",\n ".join(column_definitions)
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS hk_stock_connect (
{columns_sql}
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票信息表';
"""
try:
db.execute_update(create_table_sql)
logger.info("股票信息表创建成功")
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def insert_data_to_db(self, db: MySQLHelper) -> bool:
"""Upsert self.df into hk_stock_connect, marking each row in_out = 1.

Builds one INSERT ... ON DUPLICATE KEY UPDATE statement from the
configured column mapping and executes it in batches of 1000.

Returns:
    bool: True when all batches executed, else False.
"""
if self.df is None or self.df.empty:
logger.error("没有有效数据可插入")
return False
# All target column names from the mapping.
mapped_columns = list(self.column_mapping.values())
# INSERT column list includes in_out, whose value is the literal 1.
columns_sql = ", ".join(mapped_columns + ['in_out'])
placeholders = ", ".join(["%s"] * len(mapped_columns) + ["1"]) # in_out fixed to 1
# ON DUPLICATE KEY UPDATE: refresh every non-key column...
update_clauses = []
for column in mapped_columns:
if column != 'stock_code': # primary key is never updated
update_clauses.append(f"{column} = VALUES({column})")
# ...and re-assert membership (VALUES(in_out) is the literal 1 above).
update_clauses.append("in_out = VALUES(in_out)")
update_sql = ", ".join(update_clauses)
insert_sql = f"""
INSERT INTO hk_stock_connect (
{columns_sql}
) VALUES (
{placeholders}
)
ON DUPLICATE KEY UPDATE
{update_sql}
"""
# Parameter tuples contain only the mapped columns; NaN becomes NULL.
params_list = []
for _, row in self.df.iterrows():
params = []
for column in mapped_columns:
# Convert pandas NaN to None so MySQL stores NULL.
value = row[column] if column in row and pd.notna(row[column]) else None
params.append(value)
params_list.append(tuple(params))
# Batched execution to keep transactions small.
try:
total_rows = len(params_list)
if total_rows == 0:
logger.error("没有有效数据可插入")
return False
batch_size = 1000 # rows per executemany batch
logger.info(f"开始插入数据,共 {total_rows} 条记录")
# NOTE(review): affected_rows per batch is not accumulated; the
# final log reports total_rows, not the DB-reported count.
for i in range(0, total_rows, batch_size):
batch_params = params_list[i:i+batch_size]
affected_rows = db.execute_many(insert_sql, batch_params)
logger.info(f"已处理 {min(i+batch_size, total_rows)}/{total_rows} 条记录")
logger.info(f"成功插入/更新 {total_rows} 条记录")
return True
except Exception as e:
logger.error(f"插入数据失败: {e}")
# Dump a few parameter tuples to help debugging.
if params_list:
logger.debug(f"前5个参数示例: {params_list[:5]}")
return False
def setInOut(self, db: MySQLHelper) -> bool:
    """Reset every row's in_out flag to 0 without touching updated_at.

    Called before an import: the subsequent upsert re-marks the stocks
    present in the new list with in_out = 1, so anything left at 0 has
    dropped out of the Stock Connect list.

    Returns:
        bool: True when the UPDATE ran, else False.
    """
    try:
        # Bug fix: the statement previously set in_out = 1, which made
        # this "reset" a no-op (every row stayed flagged forever). The
        # docstring and the log message below always described a reset
        # to 0. "updated_at = updated_at" freezes the auto-update column.
        update_sql = "UPDATE hk_stock_connect SET in_out = 0, updated_at = updated_at"
        affected_rows = db.execute_update(update_sql)
        logger.info(f"成功设置 {affected_rows} 条记录的进出标志为0")
        return True
    except Exception as e:
        logger.error(f"设置进出标志失败: {e}")
        return False
def verify_data_in_db(self, db: MySQLHelper, sample_size: int = 5) -> bool:
"""Sanity-check the imported data: log the row count and a random sample.

Args:
    db: open MySQLHelper connection.
    sample_size: number of random rows to log (default 5).

Returns:
    bool: True when the queries ran, else False.
"""
try:
# Total row count.
count_sql = "SELECT COUNT(*) AS total FROM hk_stock_connect"
result = db.execute_query(count_sql)
db_count = result[0]['total'] if result else 0
logger.info(f"数据库中共有 {db_count} 条记录")
# Columns to display: all mapped columns plus the in_out flag.
mapped_columns = list(self.column_mapping.values())
select_columns = mapped_columns + ['in_out'] if mapped_columns else ["*"]
columns_sql = ", ".join(select_columns)
# Random sample via ORDER BY RAND() (fine at this table size).
sample_sql = f"""
SELECT {columns_sql}
FROM hk_stock_connect
ORDER BY RAND()
LIMIT {sample_size}
"""
samples = db.execute_query(sample_sql)
logger.info("\n随机抽样记录:")
for idx, sample in enumerate(samples, 1):
# Build a "col: value" summary, skipping NULL columns.
sample_info = []
for column in select_columns:
if column in sample and sample[column] is not None:
sample_info.append(f"{column}: {sample[column]}")
logger.info(f"{idx}. {' | '.join(sample_info) if sample_info else 'No data to display'}")
return True
except Exception as e:
logger.error(f"数据验证失败: {e}")
return False
def run_import(self) -> bool:
"""Run the full import pipeline: find, validate, read, clean, load, verify.

Returns:
    bool: True when every step succeeded, else False.
"""
logger.info(f"开始导入股票数据,数据目录: {self.data_dir}")
start_time = datetime.now()
# 1. Locate the CSV file.
csv_file = self.find_csv_file()
if not csv_file:
return False
# 2. Validate it (exists, non-empty).
if not self.validate_file(csv_file):
return False
# 3. Parse it into self.df.
if not self.read_csv_data(csv_file):
return False
# 4. Run the light cleaning/validation pass.
if not self.clean_stock_data():
return False
# Preview the first 5 rows, showing whichever columns are present.
logger.info("\n前5条股票数据:")
for i, row in self.df.head().iterrows():
display_info = []
if 'stock_code' in row:
display_info.append(f"代码: {row['stock_code']}")
if 'stock_name_chn' in row:
display_info.append(f"名称: {row['stock_name_chn']}")
if 'stock_name_en' in row:
display_info.append(f"英文: {row['stock_name_en']}")
logger.info(f"{i+1}. {' | '.join(display_info)}")
# 5. Connect to the database and load the data.
try:
with MySQLHelper(**self.db_config) as db:
# 5.1 Ensure the target table exists.
if not self.create_stocks_table(db):
return False
# Reset in_out flags before the upsert (see setInOut);
# failure here is not treated as fatal.
self.setInOut(db)
# 5.2 Upsert the rows.
if not self.insert_data_to_db(db):
return False
# 5.3 Log a verification sample.
if not self.verify_data_in_db(db):
return False
except Exception as e:
logger.error(f"数据库操作异常: {e}")
return False
# Report total elapsed time.
duration = datetime.now() - start_time
logger.info(f"数据处理成功完成! 总耗时: {duration.total_seconds():.2f}")
return True
if __name__ == "__main__":
# Database connection settings.
# NOTE(review): credentials are hardcoded here (and duplicated in the
# GUI's import task) — consider moving them to configuration.
db_config = {
'host': 'localhost',
'user': 'root',
'password': 'bzskmysql',
'database': 'hk_kline_1d'
}
# CSV header -> database column mapping.
COLUMN_MAPPING = {
'证券代码': 'stock_code',
'中文简称': 'stock_name_chn',
'英文简称': 'stock_name_en',
}
# Resolve the directory containing this script (cwd as fallback).
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
# Data directory holding the CSV to import; created if missing.
DATA_DIR = current_dir / "data"
DATA_DIR.mkdir(exist_ok=True, parents=True)
# Self-install chardet when it is missing (encoding detection dependency).
try:
import chardet
except ImportError:
logger.info("安装chardet库以支持编码检测...")
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", "chardet"])
import chardet
# Build the importer and run the full pipeline.
importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR)
# Override the default filename ("GPLIST.csv") with the Stock Connect list.
importer.setUploadfile("港股通标的证券名单.csv")
if importer.run_import():
logger.info("股票数据导入成功!")
else:
logger.error("股票数据导入失败,请检查日志了解详情")

View File

@@ -1,3 +1,4 @@
from .MySQLHelper import MySQLHelper
from .LogHelper import LogHelper
from .Config import ConfigInfo
from .Config import ConfigInfo
from .StockDataImporter import StockDataImporter

View File

@@ -895,4 +895,8 @@ HK.08218
HK.06960
HK.02936
HK.03858
HK.08132
HK.08132
HK.02938
HK.02580
HK.02941
HK.02935

2
config/Removecode.txt Normal file
View File

@@ -0,0 +1,2 @@
HK.04335
HK.02292

View File

@@ -686,7 +686,6 @@ HK.02179
HK.00442
HK.01959
HK.01985
HK.02992
HK.00314
HK.01459
HK.01082

View File

@@ -0,0 +1 @@
HK.02938

View File

@@ -1,5 +1,5 @@
kevin_futu:
已使用 1000 行情
已使用 1000 行情 -> 移除 HK.02992
hang_futu:
已使用 703 行情
HK牛仔

View File

@@ -102,6 +102,9 @@ class MainWindow(QMainWindow):
# 创建功能按钮组
self.create_button_group(main_layout)
# 创建数据导入按钮组
self.create_import_button_group(main_layout)
# 创建进度条
self.progress_bar = QProgressBar()
# self.progress_bar.setVisible(False)
@@ -307,6 +310,20 @@ class MainWindow(QMainWindow):
button_group.setLayout(button_layout)
layout.addWidget(button_group)
def create_import_button_group(self, layout):
"""Build the "数据导入" group box with its single import button.

Args:
    layout: parent layout the group box is appended to.
"""
import_group = QGroupBox("数据导入")
import_layout = QHBoxLayout()
# Import button: kicks off the CSV -> database import task.
self.btn_import = QPushButton('导入股票数据')
self.btn_import.clicked.connect(self.on_import_clicked)
self.btn_import.setToolTip('从CSV文件导入股票数据到数据库')
import_layout.addWidget(self.btn_import)
import_group.setLayout(import_layout)
layout.addWidget(import_group)
def create_log_area(self, layout):
"""创建日志显示区域"""
@@ -341,6 +358,8 @@ class MainWindow(QMainWindow):
self.btn_export.setEnabled(enabled)
self.btn_calculate.setEnabled(enabled)
self.btn_check.setEnabled(enabled)
self.btn_import.setEnabled(enabled)
self.btn_float_share.setEnabled(enabled)
def on_update_clicked(self):
"""更新数据按钮点击事件"""
@@ -401,6 +420,17 @@ class MainWindow(QMainWindow):
worker.finished_signal.connect(self.on_task_finished)
worker.start()
self.worker_threads.append(worker)
def on_import_clicked(self):
"""Handle the import button: run import_stock_data on a worker thread."""
self.log_message("开始导入股票数据...")
# Disable the buttons until on_task_finished re-enables them.
self.set_buttons_enabled(False)
worker = WorkerThread(self.import_stock_data)
worker.log_signal.connect(self.log_message)
worker.finished_signal.connect(self.on_task_finished)
worker.start()
# Keep a reference so the thread is not garbage-collected mid-run.
self.worker_threads.append(worker)
def on_task_finished(self, success, message):
"""任务完成回调"""
@@ -486,8 +516,9 @@ class MainWindow(QMainWindow):
# 移除人民币交易的股票股票名称最后一个字符为R误删除的从配置文件读回来
reserved_codes = calculator.read_stock_codes_list(Path.cwd() / "config" / "Reservedcode.txt")
remove_codes = calculator.read_stock_codes_list(Path.cwd() / "config" / "Removecode.txt")
market_data_ll = calculator.get_stock_codes() # 使用按照价格和流通股数量筛选的那个表格
market_data = market_data_ll + reserved_codes
market_data = market_data_ll + reserved_codes - remove_codes
# 根据统计时间进行命名
target_table_name = 'hk_monthly_avg_' + datetime.now().strftime("%Y%m%d")
@@ -496,11 +527,10 @@ class MainWindow(QMainWindow):
# 使用tqdm创建进度条
for code in tqdm(market_data, desc="处理股票数据", unit=""):
tablename = 'hk_' + code[3:]
# 计算并保存月度均值
calculator.calculate_and_save_monthly_avg(
source_table=tablename,
target_table=target_table_name
stock_code =code,
target_table = target_table_name
)
# self.log_message("月度平均计算完成")
@@ -529,6 +559,54 @@ class MainWindow(QMainWindow):
except Exception as e:
self.log_message(f"数据检查失败: {str(e)}")
raise
def import_stock_data(self):
"""Worker-thread task: import the Stock Connect CSV into the database.

Returns:
    bool: True on success, False on a reported failure.

Raises:
    Re-raises any unexpected exception after logging it.
"""
try:
# Local imports keep the GUI module importable without these deps.
from base.StockDataImporter import StockDataImporter
from base.MySQLHelper import MySQLHelper
from pathlib import Path
# Database connection settings.
# NOTE(review): duplicated from StockDataImporter's __main__ —
# consider a single shared config source.
db_config = {
'host': 'localhost',
'user': 'root',
'password': 'bzskmysql',
'database': 'hk_kline_1d'
}
# CSV header -> database column mapping.
COLUMN_MAPPING = {
'证券代码': 'stock_code',
'中文简称': 'stock_name_chn',
'英文简称': 'stock_name_en',
}
# Data directory next to this file (cwd as fallback); created if missing.
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
DATA_DIR = current_dir / "data"
DATA_DIR.mkdir(exist_ok=True, parents=True)
# Build the importer.
importer = StockDataImporter(db_config, COLUMN_MAPPING, DATA_DIR)
# Use the Stock Connect list filename instead of the default.
importer.setUploadfile("港股通标的证券名单.csv")
# Run the full pipeline and report the outcome to the GUI log.
if importer.run_import():
self.log_message("股票数据导入成功!")
return True
else:
self.log_message("股票数据导入失败,请检查日志了解详情")
return False
except Exception as e:
self.log_message(f"导入股票数据失败: {str(e)}")
raise
def closeEvent(self, event):
"""窗口关闭事件"""