>在postgresql中查找数组之间的距离吗?

n3schb8v  于 2021-07-26  发布在  Java
关注(0)|答案(1)|浏览(450)

据我在本文中了解,在处理几何数据类型时,可以使用<->距离运算符查找最近邻:

  1. SELECT name, location --location is point
  2. FROM geonames
  3. ORDER BY location <-> '(29.9691,-95.6972)'
  4. LIMIT 5;

您还可以使用sp gist索引获得一些优化:

  1. CREATE INDEX idx_spgist_geonames_location ON geonames USING spgist(location);

但是我在文档中找不到任何关于对数组使用<->运算符的内容。如果我使用 double precision[] 而不是 point 举个例子,这样行吗?

xqk2d5yq

xqk2d5yq1#

显然,我们不能。例如,我有一个简单的表:

  1. CREATE TABLE test (
  2. id SERIAL PRIMARY KEY,
  3. loc double precision[]
  4. );

我想从中查询文档,按距离排序,

  1. SELECT loc FROM test ORDER BY loc <-> ARRAY[0, 0, 0, 0]::double precision[];

它不起作用:

  1. Query Error: error: operator does not exist: double precision[] <-> double precision[]

文档中也没有提到数组的<->。我在这个问题的公认答案中找到了一个解决方法,但它有一些限制,特别是在数组长度上。尽管有一篇文章(用俄语写)建议在数组大小限制方面采取一种变通方法。创建示例表:

  1. import postgresql
  2. def setup_db():
  3. db = postgresql.open('pq://user:pass@localhost:5434/db')
  4. db.execute("create extension if not exists cube;")
  5. db.execute("drop table if exists vectors")
  6. db.execute("create table vectors (id serial, file varchar, vec_low cube, vec_high cube);")
  7. db.execute("create index vectors_vec_idx on vectors (vec_low, vec_high);")

元素插入:

  1. query = "INSERT INTO vectors (file, vec_low, vec_high) VALUES ('{}', CUBE(array[{}]), CUBE(array[{}]))".format(
  2. file_name,
  3. ','.join(str(s) for s in encodings[0][0:64]),
  4. ','.join(str(s) for s in encodings[0][64:128]),
  5. )
  6. db.execute(query)

元素查询:

  1. import time
  2. import postgresql
  3. import random
  4. db = postgresql.open('pq://user:pass@localhost:5434/db')
  5. for i in range(100):
  6. t = time.time()
  7. encodings = [random.random() for i in range(128)]
  8. threshold = 0.6
  9. query = "SELECT file FROM vectors WHERE sqrt(power(CUBE(array[{}]) <-> vec_low, 2) + power(CUBE(array[{}]) <-> vec_high, 2)) <= {} ".format(
  10. ','.join(str(s) for s in encodings[0:64]),
  11. ','.join(str(s) for s in encodings[64:128]),
  12. threshold,
  13. ) + \
  14. "ORDER BY sqrt(power(CUBE(array[{}]) <-> vec_low, 2) + power(CUBE(array[{}]) <-> vec_high, 2)) ASC LIMIT 1".format(
  15. ','.join(str(s) for s in encodings[0:64]),
  16. ','.join(str(s) for s in encodings[64:128]),
  17. )
  18. print(db.query(query))
  19. print('inset time', time.time() - t, 'ind', i)
展开查看全部

相关问题