From 34c49da26cbc5737b6d5fddb909002413935047d Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Fri, 10 Jan 2025 17:09:29 -0800 Subject: [PATCH 1/3] Initial pass, needs tests+docs --- dbg.sql | 21 ++++++++ sqlite-vec.c | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 dbg.sql diff --git a/dbg.sql b/dbg.sql new file mode 100644 index 00000000..f5fab9c0 --- /dev/null +++ b/dbg.sql @@ -0,0 +1,21 @@ +.load dist/vec0 + + +create virtual table vec_items using vec0( + vector float[1] +); + +insert into vec_items(rowid, vector) + select value, json_array(value) from generate_series(1, 100); + + +select vec_to_json(vector), distance +from vec_items +where vector match '[1]' + and k = 5; + +select vec_to_json(vector), distance +from vec_items +where vector match '[1]' + and k = 5 + and distance > 4.0; \ No newline at end of file diff --git a/sqlite-vec.c b/sqlite-vec.c index 3cc802f0..982e5f88 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -5305,11 +5305,21 @@ static int vec0Close(sqlite3_vtab_cursor *cur) { typedef enum { // If any values are updated, please update the ARCHITECTURE.md docs accordingly! + // ~~~ KNN QUERIES ~~~ // VEC0_IDXSTR_KIND_KNN_MATCH = '{', VEC0_IDXSTR_KIND_KNN_K = '}', VEC0_IDXSTR_KIND_KNN_ROWID_IN = '[', + // argv[i] is a constraint on a PARTITION KEY column in a KNN query + // VEC0_IDXSTR_KIND_KNN_PARTITON_CONSTRAINT = ']', + + // argv[i] is a constraint on the distance column in a KNN query + VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT = '*', + + // ~~~ POINT QUERIES ~~~ // VEC0_IDXSTR_KIND_POINT_ID = '!', + + // ~~~ ??? ~~~ // VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&', } vec0_idxstr_kind; @@ -5318,11 +5328,22 @@ typedef enum { typedef enum { // If any values are updated, please update the ARCHITECTURE.md docs accordingly! 
+ // Equality constraint on a PARTITION KEY column, ex `user_id = 123` VEC0_PARTITION_OPERATOR_EQ = 'a', + + // "Greater than" constraint on a PARTITION KEY column, ex `year > 2024` VEC0_PARTITION_OPERATOR_GT = 'b', + + // "Less than or equal to" constraint on a PARTITION KEY column, ex `year <= 2024` VEC0_PARTITION_OPERATOR_LE = 'c', + + // "Less than" constraint on a PARTITION KEY column, ex `year < 2024` VEC0_PARTITION_OPERATOR_LT = 'd', + + // "Greater than or equal to" constraint on a PARTITION KEY column, ex `year >= 2024` VEC0_PARTITION_OPERATOR_GE = 'e', + + // "Not equal to" constraint on a PARTITION KEY column, ex `year != 2024` VEC0_PARTITION_OPERATOR_NE = 'f', } vec0_partition_operator; typedef enum { @@ -5335,6 +5356,15 @@ typedef enum { VEC0_METADATA_OPERATOR_IN = 'g', } vec0_metadata_operator; + +typedef enum { + // Operator of a distance column constraint in a KNN query; stored as the 2nd char of that constraint's 4-char idxStr entry + VEC0_DISTANCE_CONSTRAINT_GT = 'a', + VEC0_DISTANCE_CONSTRAINT_GE = 'b', + VEC0_DISTANCE_CONSTRAINT_LT = 'c', + VEC0_DISTANCE_CONSTRAINT_LE = 'd', +} vec0_distance_constraint_operator; + static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { vec0_vtab *p = (vec0_vtab *)pVTab; /** @@ -5494,6 +5524,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } #endif + // find any PARTITION KEY column constraints for (int i = 0; i < pIdxInfo->nConstraint; i++) { if (!pIdxInfo->aConstraint[i].usable) continue; @@ -5548,6 +5579,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } + // find any metadata column constraints for (int i = 0; i < pIdxInfo->nConstraint; i++) { if (!pIdxInfo->aConstraint[i].usable) continue; @@ -5644,6 +5676,58 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } + // find any distance column constraints + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + if (!pIdxInfo->aConstraint[i].usable) + continue; + + int iColumn = pIdxInfo->aConstraint[i].iColumn; + int op = pIdxInfo->aConstraint[i].op; + if(op == 
SQLITE_INDEX_CONSTRAINT_LIMIT || op == SQLITE_INDEX_CONSTRAINT_OFFSET) { + continue; + } + if(vec0_column_distance_idx(p) != iColumn) { + continue; + } + + char value = 0; + switch(op) { + case SQLITE_INDEX_CONSTRAINT_GT: { + value = VEC0_DISTANCE_CONSTRAINT_GT; + break; + } + case SQLITE_INDEX_CONSTRAINT_GE: { + value = VEC0_DISTANCE_CONSTRAINT_GE; + break; + } + case SQLITE_INDEX_CONSTRAINT_LT: { + value = VEC0_DISTANCE_CONSTRAINT_LT; + break; + } + case SQLITE_INDEX_CONSTRAINT_LE: { + value = VEC0_DISTANCE_CONSTRAINT_LE; + break; + } + default: { + // IMP TODO + rc = SQLITE_ERROR; + vtab_set_error( + pVTab, + "Illegal WHERE constraint on distance column in a KNN query. " + "Only one of GT, GE, LT, LE constraints are allowed." + ); + goto done; + } + } + + pIdxInfo->aConstraintUsage[i].argvIndex = argvIndex++; + pIdxInfo->aConstraintUsage[i].omit = 1; + sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT); + sqlite3_str_appendchar(idxStr, 1, value); + sqlite3_str_appendchar(idxStr, 1, '_'); + sqlite3_str_appendchar(idxStr, 1, '_'); + } + pIdxInfo->idxNum = iMatchVectorTerm; @@ -5672,7 +5756,6 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { } pIdxInfo->needToFreeIdxStr = 1; - rc = SQLITE_OK; done: @@ -6560,12 +6643,15 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, int numValueEntries = (idxStrLength-1) / 4; assert(numValueEntries == argc); int hasMetadataFilters = 0; + int hasDistanceConstraints = 0; for(int i = 0; i < argc; i++) { int idx = 1 + (i * 4); char kind = idxStr[idx + 0]; if(kind == VEC0_IDXSTR_KIND_METADATA_CONSTRAINT) { hasMetadataFilters = 1; - break; + } + else if(kind == VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT) { + hasDistanceConstraints = 1; } } @@ -6752,6 +6838,55 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks, chunk_distances[i] = result; } + if(hasDistanceConstraints) { + for(int i = 0; i < argc; i++) { + int idx = 1 + (i * 4); + char 
kind = idxStr[idx + 0]; + // Compare in double precision: the bound arrives as a double, and narrowing it to f32 can flip the result at the boundary (e.g. 0.1) + double target = sqlite3_value_double(argv[i]); + + if(kind != VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT) { + continue; + } + vec0_distance_constraint_operator op = idxStr[idx + 1]; + + switch(op) { + case VEC0_DISTANCE_CONSTRAINT_GE: { + for(int i = 0; i < p->chunk_size;i++) { + if(bitmap_get(b, i) && !(chunk_distances[i] >= target)) { + bitmap_set(b, i, 0); + } + } + break; + } + case VEC0_DISTANCE_CONSTRAINT_GT: { + for(int i = 0; i < p->chunk_size;i++) { + if(bitmap_get(b, i) && !(chunk_distances[i] > target)) { + bitmap_set(b, i, 0); + } + } + break; + } + case VEC0_DISTANCE_CONSTRAINT_LE: { + for(int i = 0; i < p->chunk_size;i++) { + if(bitmap_get(b, i) && !(chunk_distances[i] <= target)) { + bitmap_set(b, i, 0); + } + } + break; + } + case VEC0_DISTANCE_CONSTRAINT_LT: { + for(int i = 0; i < p->chunk_size;i++) { + if(bitmap_get(b, i) && !(chunk_distances[i] < target)) { + bitmap_set(b, i, 0); + } + } + break; + } + } + } + } + int used1; min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs, min(k, p->chunk_size), bTaken, &used1); From bbb323820928c256aa30b9089c918ffceb600a96 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Thu, 12 Feb 2026 21:35:50 -0800 Subject: [PATCH 2/3] old: test-knn-constraints --- test.sql | 2 +- .../test-knn-distance-constraints.ambr | 273 ++++++++++++++++++ tests/test-knn-distance-constraints.py | 82 ++++++ 3 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 tests/__snapshots__/test-knn-distance-constraints.ambr create mode 100644 tests/test-knn-distance-constraints.py diff --git a/test.sql b/test.sql index 9d615a7f..8cd3f30e 100644 --- a/test.sql +++ b/test.sql @@ -1,5 +1,5 @@ -.load dist/vec0main +.load dist/vec0 .bail on .mode qbox diff --git a/tests/__snapshots__/test-knn-distance-constraints.ambr b/tests/__snapshots__/test-knn-distance-constraints.ambr new file mode 100644 index 00000000..87695a1d --- /dev/null +++ 
b/tests/__snapshots__/test-knn-distance-constraints.ambr @@ -0,0 +1,273 @@ +# serializer version: 1 +# name: test_normal + OrderedDict({ + 'sql': 'SELECT * FROM v', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'embedding': b'\x00\x00\x80?', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 2, + 'embedding': b'\x00\x00\x00@', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 3, + 'embedding': b'\x00\x00@@', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 4, + 'embedding': b'\x00\x00\x80@', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 5, + 'embedding': b'\x00\x00\xa0@', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 6, + 'embedding': b'\x00\x00\xc0@', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 7, + 'embedding': b'\x00\x00\xe0@', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 8, + 'embedding': b'\x00\x00\x00A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 9, + 'embedding': b'\x00\x00\x10A', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 10, + 'embedding': b'\x00\x00 A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 11, + 'embedding': b'\x00\x000A', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 12, + 'embedding': b'\x00\x00@A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 13, + 'embedding': b'\x00\x00PA', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 14, + 'embedding': b'\x00\x00`A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 15, + 'embedding': b'\x00\x00pA', + 'is_odd': 1, + }), + OrderedDict({ + 'rowid': 16, + 'embedding': b'\x00\x00\x80A', + 'is_odd': 0, + }), + OrderedDict({ + 'rowid': 17, + 'embedding': b'\x00\x00\x88A', + 'is_odd': 1, + }), + ]), + }) +# --- +# name: test_normal.1 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? 
', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'distance': 0.0, + }), + OrderedDict({ + 'rowid': 2, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 3, + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 4, + 'distance': 3.0, + }), + OrderedDict({ + 'rowid': 5, + 'distance': 4.0, + }), + ]), + }) +# --- +# name: test_normal.2 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance > 5', + 'rows': list([ + OrderedDict({ + 'rowid': 7, + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 8, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- +# name: test_normal.3 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance >= 5', + 'rows': list([ + OrderedDict({ + 'rowid': 6, + 'distance': 5.0, + }), + OrderedDict({ + 'rowid': 7, + 'distance': 6.0, + }), + OrderedDict({ + 'rowid': 8, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + ]), + }) +# --- +# name: test_normal.4 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance < 3', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'distance': 0.0, + }), + OrderedDict({ + 'rowid': 2, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 3, + 'distance': 2.0, + }), + ]), + }) +# --- +# name: test_normal.5 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? 
AND distance <= 3', + 'rows': list([ + OrderedDict({ + 'rowid': 1, + 'distance': 0.0, + }), + OrderedDict({ + 'rowid': 2, + 'distance': 1.0, + }), + OrderedDict({ + 'rowid': 3, + 'distance': 2.0, + }), + OrderedDict({ + 'rowid': 4, + 'distance': 3.0, + }), + ]), + }) +# --- +# name: test_normal.6 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance > 7 AND distance <= 10', + 'rows': list([ + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- +# name: test_normal.7 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? AND distance BETWEEN 7 AND 10', + 'rows': list([ + OrderedDict({ + 'rowid': 8, + 'distance': 7.0, + }), + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 10, + 'distance': 9.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- +# name: test_normal.8 + OrderedDict({ + 'sql': 'select rowid, distance from v where embedding match ? and k = ? 
AND is_odd == TRUE AND distance BETWEEN 7 AND 10', + 'rows': list([ + OrderedDict({ + 'rowid': 9, + 'distance': 8.0, + }), + OrderedDict({ + 'rowid': 11, + 'distance': 10.0, + }), + ]), + }) +# --- diff --git a/tests/test-knn-distance-constraints.py b/tests/test-knn-distance-constraints.py new file mode 100644 index 00000000..ed2d9ec9 --- /dev/null +++ b/tests/test-knn-distance-constraints.py @@ -0,0 +1,82 @@ +import sqlite3 +from collections import OrderedDict + + +def test_normal(db, snapshot): + db.execute("create virtual table v using vec0(embedding float[1], is_odd boolean, chunk_size=8)") + db.executemany( + "insert into v(rowid, is_odd, embedding) values (?1, ?1 % 2, ?2)", + [ + [1, "[1]"], + [2, "[2]"], + [3, "[3]"], + [4, "[4]"], + [5, "[5]"], + [6, "[6]"], + [7, "[7]"], + [8, "[8]"], + [9, "[9]"], + [10, "[10]"], + [11, "[11]"], + [12, "[12]"], + [13, "[13]"], + [14, "[14]"], + [15, "[15]"], + [16, "[16]"], + [17, "[17]"], + ], + ) + assert exec(db,"SELECT * FROM v") == snapshot() + + BASE_KNN = "select rowid, distance from v where embedding match ? and k = ? 
" + assert exec(db, BASE_KNN, ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND distance > 5", ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND distance >= 5", ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND distance < 3", ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND distance <= 3", ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND distance > 7 AND distance <= 10", ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND distance BETWEEN 7 AND 10", ["[1]", 5]) == snapshot() + assert exec(db, BASE_KNN + "AND is_odd == TRUE AND distance BETWEEN 7 AND 10", ["[1]", 5]) == snapshot() + + +class Row: + def __init__(self): + pass + + def __repr__(self) -> str: + return repr(self.__dict__) + + +def exec(db, sql, parameters=()): + try: + rows = db.execute(sql, parameters).fetchall() + except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: + return { + "error": e.__class__.__name__, + "message": str(e), + } + a = [] + for row in rows: + o = OrderedDict() + for k in row.keys(): + o[k] = row[k] + a.append(o) + result = OrderedDict() + result["sql"] = sql + result["rows"] = a + return result + + +def vec0_shadow_table_contents(db, v): + shadow_tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like ? 
order by 1", [f"{v}_%"] + ).fetchall() + ] + o = {} + for shadow_table in shadow_tables: + if shadow_table.endswith("_info"): + continue + o[shadow_table] = exec(db, f"select * from {shadow_table}") + return o From 5e4226a2571dc6b683d2788e57f0e9f50bfcb101 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Fri, 13 Feb 2026 06:38:01 -0800 Subject: [PATCH 3/3] cleanup --- dbg.sql | 21 --------------------- test.sql | 2 +- 2 files changed, 1 insertion(+), 22 deletions(-) delete mode 100644 dbg.sql diff --git a/dbg.sql b/dbg.sql deleted file mode 100644 index f5fab9c0..00000000 --- a/dbg.sql +++ /dev/null @@ -1,21 +0,0 @@ -.load dist/vec0 - - -create virtual table vec_items using vec0( - vector float[1] -); - -insert into vec_items(rowid, vector) - select value, json_array(value) from generate_series(1, 100); - - -select vec_to_json(vector), distance -from vec_items -where vector match '[1]' - and k = 5; - -select vec_to_json(vector), distance -from vec_items -where vector match '[1]' - and k = 5 - and distance > 4.0; \ No newline at end of file diff --git a/test.sql b/test.sql index 8cd3f30e..9d615a7f 100644 --- a/test.sql +++ b/test.sql @@ -1,5 +1,5 @@ -.load dist/vec0 +.load dist/vec0main .bail on .mode qbox