LLM2CSQL

thought

看完上述文章的思考,如何将LLM2SQL的思路迁移到LLM2CYPHER上?

LLM2CYPHER

  1. 用小模型来将自然语言转换为规范的查询自然语言增加转换准确性

查询目标:n1,n2,n3 //查询中的实体
查询关系:n1-r1-n2 ….. //查询中实体的关系
返回结果:result //查询语句中的结果

例如:小丽的朋友的朋友有哪些

  • 查询目标:小丽,小丽的朋友,小丽的朋友的朋友
  • 查询关系:小丽-朋友-小丽的朋友,小丽的朋友-朋友-小丽的朋友的朋友
  • 返回结果:小丽的朋友的朋友
  1. 多专家评判机制
    使用多个大模型专家进行nl2cypher,每个专家生成的cypher和原先的自然语言一起交给标号i+1的大模型专家,让其评判正确率,并且自己在根据自然语言生成cypher,与前一个专家转交的cypher进行比对,给予每个大模型专家不同的权重,当每个大模型一致认为正确时,转换完成

  2. 自适应学习机制
    添加自适应学习机制,并且给与不同专家不同的权重,例如某个专业侧重于整体结构,或者是数据类型等

  3. 物化视图缓存机制
    给上述的方案增添一种类似物化视图的结构,用来保存经过小模型转换的标准自然语言转换为的认为正确的cypher。在每次nltocypher的时候:

  • 先经过小模型的转换
  • 去查看物化视图中是否有保存
  • 若有则可以直接使用,或者取出让大模型专家进行评判
  • 若正确率低则需重新生成,否则则认为正确

可以对自然语言进行模糊匹配,找出结构相同的自然语言。比如”有多少奥迪车辆”和”有多少宝马车辆”很类似,可以使用前面的缓存交给专门进行标签匹配的大模型专家进行修改。

  1. 强化学习优化
    使用强化学习来自适应改变大语言模型专家的权重,或者根据执行结果反馈给专家,让其调整权重。

  2. 标签处理专家
    cypher的标签很重要,可以使用不同的大模型处理不同的部分,处理标签,具体的参数。

Rewrite Cypher

  1. 因为cypher的连接顺序会十分影响查询效率,所以需要对cypher进行重写,例如:
  • 原cypher:MATCH (n1)-[r1]->(n2)-[r2]->(n3) RETURN n1,n2,n3
  • 重写cypher:MATCH (n3)<-[r2]-(n2)<-[r1]-(n1) RETURN n1,n2,n3
  1. 如何判断cypher的连接顺序是否合理?
  • 使用图数据库的执行引擎,执行cypher,并记录执行时间,选择执行时间最短的连接顺序/问题在于需要执行cypher,才能知道执行时间,不能在转换阶段就进行判断
  • 使用强化学习预训练,在转换阶段就进行判断
  • 使用启发式算法,根据cypher的结构,判断连接顺序是否合理。启发式算法是一种在合理时间内找到较好解决方案的方法,它不一定能找到最优解,但可以找到接近最优解的答案。在cypher重写中,可以考虑以下启发式规则:
    1. 选择性原则:优先处理可以快速缩小结果集的条件,如带有索引或约束的节点
    2. 局部性原则:尽量将相关的操作放在一起处理,减少数据在内存中的移动
    3. 基数估计:根据节点和关系的数量来估计中间结果的大小,选择产生较小中间结果的执行顺序
    4. 模式匹配:识别常见的查询模式并应用已知的优化策略
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      class CypherRewriter:
      def __init__(self, graph_stats):
      # 图数据库统计信息,包含节点标签和关系类型的基数信息
      self.graph_stats = graph_stats

      def estimate_selectivity(self, pattern):
      """估算模式的选择性(越小越好)"""
      node_label = pattern.get('label', '')
      relationship = pattern.get('relationship', '')

      # 从统计信息中获取基数
      node_cardinality = self.graph_stats['nodes'].get(node_label, float('inf'))
      rel_cardinality = self.graph_stats['relationships'].get(relationship, float('inf'))

      # 计算选择性分数(越小表示越具有选择性)
      return (node_cardinality + rel_cardinality) / 2

      def rewrite_cypher(self, cypher_query):
      """重写Cypher查询以优化执行顺序"""
      # 解析查询模式
      patterns = self._parse_patterns(cypher_query)

      # 计算每个模式的选择性
      pattern_scores = [
      (pattern, self.estimate_selectivity(pattern))
      for pattern in patterns
      ]

      # 根据选择性排序(选择性高的优先)
      sorted_patterns = sorted(pattern_scores, key=lambda x: x[1])

      # 重建优化后的查询
      return self._rebuild_query(sorted_patterns, cypher_query)

      def _parse_patterns(self, cypher_query):
      """解析Cypher查询中的模式
      示例: MATCH (n:Person)-[r:KNOWS]->(m:Person)
      会被解析为多个模式组件
      """
      # 这里需要实现具体的解析逻辑
      # 返回解析后的模式列表
      patterns = []
      # ... 解析逻辑 ...
      return patterns

      def _rebuild_query(self, sorted_patterns, original_query):
      """根据排序后的模式重建查询"""
      # 重建优化后的查询字符串
      # ... 重建逻辑 ...
      return optimized_query

      # 使用示例
      graph_stats = {
      'nodes': {
      'Person': 1000,
      'Company': 100,
      },
      'relationships': {
      'KNOWS': 5000,
      'WORKS_AT': 1000,
      }
      }

      rewriter = CypherRewriter(graph_stats)
      query = "MATCH (n:Person)-[r:KNOWS]->(m:Person)-[w:WORKS_AT]->(c:Company) RETURN n,m,c"
      optimized_query = rewriter.rewrite_cypher(query)

将自动补全技术融入到LLM2SQL中提高可信度

在转换阶段生成了类似的模板之后,可以使用自动补全的算法来判断是否数据库中含有相应的结构,如果没有那么它肯定是错误的,若有也不能说是正确的,但是可信度会提高很多。
当然因为会提供给大模型Scheme,所以不需要真的使用数据库来进行自动补全,只需要根据Scheme进行判断即可。

PPT

ZNL2SQL

report

CODE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#ppt中的多级匹配函数实现
def value_match(db, question, gold_sql, pred_sql, table, column, pred_value, match_method, k):
item = Item(db, table, column, pred_value)
pred_value = str(pred_value)

with sqlite3.connect(item.db) as conn:
conn.text_factory = lambda x: str(x, "utf8", "ignore")
score, candidate = search_in_column(conn, item, match_method, k)
max_score = score[0]

if str(candidate[0]).strip().lower() == pred_value.strip().lower():
return table + '.' + column, None

table_score, table_column, table_candidate = search_in_table(conn, item, "fuzzy", k)
if len(table_score) > 0:
if str(table_candidate[0]).strip().lower() == pred_value.strip().lower():
return table + '.' + table_column[0], table_candidate[0]

if max_score > 0.65:
return table + '.' + column, number_check(pred_value, candidate[0])

table_score, table_column, table_candidate = search_in_table(conn, item, match_method, k)
if len(table_score) > 0:
max_score = table_score[0]
if max_score > 0.65:
if number_check(pred_value, table_candidate[0]) is None:
return table + '.' + column, None
return table + '.' + table_column[0], table_candidate[0]

database_score, database_table, database_column, database_candidate = search_in_database(conn, item, match_method, k)
if len(database_score) > 0:
max_score = database_score[0]
if max_score > 0.65:
if number_check(pred_value, database_candidate[0]) is None:
return table + '.' + column, None
return database_table[0] + '.' + database_column[0], database_candidate[0]

return table + '.' + column, number_check(pred_value, candidate[0])

pdf report

  • Report 更新时间:2024-11-28

TODO LIST

  1. 继续调研相关的文献,看看有没有前人做过的研究。
  2. 调研相关技术例如自动补全,启发式算法等
  3. 简单做一下各个想法的系统框图
  4. 查看AI4ALL的报告,学习微调技术