# Product Quantization for Similarity Search

## How to compress and fit a humongous set of vectors in memory for similarity search with asymmetric distance computation (ADC)

### [Click here to read and learn how Product Quantization works (with detailed explanation and illustrations)](https://peggy1502.medium.com/2f1f67c5fddd)


In [1]:
import numpy as np
from scipy.cluster.vq import kmeans2, vq
from scipy.spatial.distance import cdist

- M = number of segments each vector is split into
- k = number of centroids per segment
- s = dimension (length) of each segment, i.e. the vector dimension divided by M

In [2]:
def PQ_train(vectors, M, k):
    """Train a product-quantization codebook: one k-means run per segment.

    Args:
        vectors: 2-D array of shape (N, D) holding the training vectors.
        M: number of segments each vector is split into; must divide D.
        k: number of centroids per segment (passed straight to ``kmeans2``).

    Returns:
        A float32 array of shape (M, k, s) with s = D // M, where
        ``codebook[m]`` holds the k centroids learned for segment m.

    Raises:
        ValueError: if D is not a multiple of M. Without this check the
            trailing ``D % M`` dimensions would silently be left out of the
            codebook.
    """
    dim = vectors.shape[1]
    if dim % M != 0:
        raise ValueError(
            f"vector dimension {dim} is not divisible by the number of segments M={M}"
        )
    s = dim // M                                       # dimension (or length) of a segment.
    codebook = np.empty((M, k, s), np.float32)

    for m in range(M):
        sub_vectors = vectors[:, m*s:(m+1)*s]          # sub-vectors for segment m.
        codebook[m], _ = kmeans2(sub_vectors, k)       # keep the centroids, discard the labels.

    return codebook

In [3]:
def PQ_encode(vectors, codebook):
    """Encode vectors as PQ codes: the nearest-centroid id for every segment.

    Args:
        vectors: 2-D array of shape (N, D) holding the vectors to encode.
        codebook: array of shape (M, k, s) as returned by ``PQ_train``.

    Returns:
        An unsigned-integer array of shape (N, M) where entry [n, m] is the
        index of the centroid in ``codebook[m]`` nearest to segment m of
        vector n. The dtype is the smallest unsigned type that can hold
        k - 1: uint8 for k <= 256, wider for larger codebooks.
    """
    M, k, s = codebook.shape
    # A fixed uint8 would silently wrap centroid ids >= 256, so size the
    # code dtype from k instead.
    code_dtype = np.min_scalar_type(max(k - 1, 0))
    PQ_code = np.empty((vectors.shape[0], M), code_dtype)

    for m in range(M):
        sub_vectors = vectors[:, m*s:(m+1)*s]           # sub-vectors for segment m.
        centroid_ids, _ = vq(sub_vectors, codebook[m])  # vq returns the nearest centroid Ids.
        PQ_code[:, m] = centroid_ids                    # assign centroid Ids to PQ_code.

    return PQ_code

In [4]:
def PQ_search(query_vector, codebook, PQ_code):
    """Approximate squared distances from a query to every PQ-encoded vector.

    This is asymmetric distance computation: the query stays uncompressed,
    and each database vector is represented only by its PQ code.

    Args:
        query_vector: 1-D array of length M * s.
        codebook: array of shape (M, k, s) as returned by ``PQ_train``.
        PQ_code: integer array of shape (N, M) as returned by ``PQ_encode``.

    Returns:
        A tuple ``(distance_table, distances)``:
        distance_table is a float32 array of shape (k, M) (note: transposed)
        whose entry [c, m] is the squared Euclidean distance between segment m
        of the query and centroid c of segment m; distances is a float32
        array of shape (N,) holding, for each code, the sum of its M partial
        distances.
    """
    M, k, s = codebook.shape

    #=====================================================================
    # Build the distance table.
    #=====================================================================

    distance_table = np.empty((M, k), np.float32)    # Shape is (M, k)

    for m in range(M):
        query_segment = query_vector[m*s:(m+1)*s]    # query vector for segment m.
        distance_table[m] = cdist([query_segment], codebook[m], "sqeuclidean")[0]

    #=====================================================================
    # Look up the partial distances from the distance table.
    #=====================================================================

    N, num_segments = PQ_code.shape
    distances = np.zeros((N, ), np.float32)

    # One vectorized gather per segment instead of a Python loop over all
    # N codes. Segments are still accumulated in order m = 0..M-1 in float32,
    # so the sums match an element-by-element loop exactly.
    for m in range(num_segments):
        distances += distance_table[m, PQ_code[:, m]]

    # Callers receive the table transposed to shape (k, M).
    return distance_table.T, distances

In [5]:
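# Alternative, vectorized version of PQ_search, kept commented out for reference.
# Unlike the active version above, it returns only `distances` (no distance table).
# def PQ_search(query_vector, codebook, PQ_code):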
#     M, k, s = codebook.shape
#     distance_table = np.empty((M, k), np.float32)              # Shape is (M, k)   
    
#     for m in range(M):
#         query_segment = query_vector[m*s:(m+1)*s]               # query vector for segment m.
        
#         distance_table[m] = cdist([query_segment], codebook[m], "sqeuclidean")[0]
# #       distance_table[m] = np.linalg.norm(codebook[m] - query_segment, axis=1) ** 2
        
#     distances = np.sum(distance_table[range(M), PQ_code], axis=1)
    
#     return distances    

# Test Case 1

#### A small example with 10 database vectors (of length 6), each split into 2 segments of length 3, with 4 centroids per segment.

#### You may use this example to verify and inspect the values for better understanding of Product Quantization.

In [6]:
# Test case 1: a data set small enough to inspect every intermediate value.
M = 2                   # Number of segments per vector
k = 4                   # Number of centroids per segment
vector_dim = 6          # Dimension of a vector
total_vectors = 10      # Number of database vectors

# Seed NumPy's global random generator so the random vectors below are reproducible.
np.random.seed(2022)
vectors = np.random.random((total_vectors, vector_dim)).astype(np.float32)   # Database vectors
q = np.random.random((vector_dim, )).astype(np.float32)                      # Query vector

# Train the codebook, encode the database vectors, then search them with the query.
codebook = PQ_train(vectors, M, k)                             # codebook has shape (M, k, s)
PQ_code = PQ_encode(vectors, codebook)                         # PQ_code has shape (total_vectors, M)
# PQ_search returns the distance table transposed to (k, M) plus one distance per database vector.
distance_table, distances = PQ_search(q, codebook, PQ_code)

In [7]:
vectors           # database vectors

array([[0.00935861, 0.4990578 , 0.11338369, 0.04997402, 0.6854076 ,
        0.48698807],
       [0.8976572 , 0.64745206, 0.8969631 , 0.7211349 , 0.8313534 ,
        0.82756805],
       [0.8335796 , 0.95704436, 0.36804444, 0.49483764, 0.3395095 ,
        0.61942935],
       [0.97752964, 0.09643308, 0.7442062 , 0.29249948, 0.29867536,
        0.7524735 ],
       [0.01866373, 0.52373743, 0.86443585, 0.38884285, 0.21219185,
        0.47518072],
       [0.5646724 , 0.3494293 , 0.97590864, 0.03782004, 0.7942697 ,
        0.3578826 ],
       [0.74796396, 0.9145093 , 0.37266243, 0.96488345, 0.08138578,
        0.04245099],
       [0.29679602, 0.36370364, 0.49025518, 0.6685187 , 0.67341465,
        0.57210064],
       [0.08059224, 0.8983313 , 0.03838853, 0.78219444, 0.03665636,
        0.26718384],
       [0.20522384, 0.25889444, 0.9326153 , 0.00812491, 0.40347317,
        0.8941022 ]], dtype=float32)

In [8]:
q                 # query vector

array([0.20420903, 0.02177601, 0.6971671 , 0.19102335, 0.546433  ,
       0.6032252 ], dtype=float32)

In [9]:
codebook          # the generated codebook

array([[[0.93759346, 0.37194258, 0.82058465],
        [0.01401117, 0.5113976 , 0.48890978],
        [0.3555641 , 0.32400915, 0.79959303],
        [0.55404526, 0.923295  , 0.25969848]],

       [[0.34058645, 0.5111673 , 0.8247146 ],
        [0.5173997 , 0.408372  , 0.55557024],
        [0.04389703, 0.7398386 , 0.42243534],
        [0.873539  , 0.05902107, 0.15481742]]], dtype=float32)

In [10]:
PQ_code           # the generated PQ codes

array([[1, 2],
       [0, 0],
       [3, 1],
       [0, 0],
       [1, 1],
       [2, 2],
       [3, 3],
       [2, 1],
       [3, 3],
       [2, 0]], dtype=uint8)

In [11]:
distance_table    # the generated distance table, transposed to shape (k, M)

array([[0.67570126, 0.07267036],
       [0.31927565, 0.12785336],
       [0.1247443 , 0.09173685],
       [1.1265007 , 0.90446746]], dtype=float32)

In [12]:
distances # the distances between the database vectors and the query vector.

array([0.4110125 , 0.7483716 , 1.2543541 , 0.7483716 , 0.447129  ,
       0.21648115, 2.0309682 , 0.25259766, 2.0309682 , 0.19741465],
      dtype=float32)

# Test Case 2
#### An example with 1 million database vectors (of length 128), each split into 8 segments of length 16, with 256 centroids per segment.

In [13]:
# Test case 2: the same pipeline as test case 1, at a realistic scale.
M = 8                     # Number of segments per vector
k = 256                   # Number of centroids per segment
vector_dim = 128          # Dimension (length) of a vector
total_vectors = 1000000   # Number of database vectors

# Seed NumPy's global random generator so the random vectors below are reproducible.
np.random.seed(2022)
vectors = np.random.random((total_vectors, vector_dim)).astype(np.float32)   # Database vectors
q = np.random.random((vector_dim, )).astype(np.float32)                      # Query vector

# Train the codebook, encode the database vectors, then search them with the query.
codebook = PQ_train(vectors, M, k)                             # codebook has shape (M, k, s)
PQ_code = PQ_encode(vectors, codebook)                         # PQ_code has shape (total_vectors, M)
# PQ_search returns the distance table transposed to (k, M) plus one distance per database vector.
distance_table, distances = PQ_search(q, codebook, PQ_code)



In [14]:
vectors           # database vectors

array([[0.00935861, 0.4990578 , 0.11338369, ..., 0.7015135 , 0.82271117,
        0.73850626],
       [0.987894  , 0.15918045, 0.9880797 , ..., 0.36329183, 0.6499846 ,
        0.6270492 ],
       [0.7823163 , 0.7571479 , 0.37924927, ..., 0.36728022, 0.5861753 ,
        0.02303002],
       ...,
       [0.9448289 , 0.49708685, 0.53720295, ..., 0.85491776, 0.79853326,
        0.3665858 ],
       [0.9057683 , 0.26638535, 0.50740963, ..., 0.4805671 , 0.5286727 ,
        0.65772986],
       [0.2985167 , 0.6905571 , 0.7515794 , ..., 0.8296632 , 0.6137684 ,
        0.14333938]], dtype=float32)

In [15]:
q                 # query vector

array([0.10070486, 0.1714809 , 0.15225852, 0.6886641 , 0.9749756 ,
       0.06685743, 0.69114894, 0.2268187 , 0.00141395, 0.13073783,
       0.92005795, 0.7369674 , 0.25174677, 0.9440858 , 0.05007194,
       0.7274404 , 0.8650318 , 0.56946415, 0.19805309, 0.75597805,
       0.69489574, 0.436116  , 0.30583853, 0.49263567, 0.44103155,
       0.93019086, 0.21397944, 0.26108417, 0.9023846 , 0.8116958 ,
       0.99607897, 0.14933279, 0.51558095, 0.08782236, 0.4078355 ,
       0.3299799 , 0.6854069 , 0.8505636 , 0.14264438, 0.19260886,
       0.6159144 , 0.2733055 , 0.9427627 , 0.98627466, 0.498921  ,
       0.08150337, 0.882082  , 0.27246374, 0.6357337 , 0.30563086,
       0.5093854 , 0.12601191, 0.5434625 , 0.21716632, 0.8092636 ,
       0.7351097 , 0.1922371 , 0.31089687, 0.24594605, 0.49831435,
       0.32414576, 0.43404552, 0.62408173, 0.83273077, 0.97268635,
       0.16446854, 0.45040593, 0.3236742 , 0.23752789, 0.08184742,
       0.8231972 , 0.7655861 , 0.82405686, 0.9156477 , 0.41033

In [16]:
codebook          # the generated codebook

array([[[0.2714486 , 0.24714273, 0.31333324, ..., 0.6957724 ,
         0.49512926, 0.3413709 ],
        [0.44175753, 0.71572846, 0.2816623 , ..., 0.24923094,
         0.7296848 , 0.5201776 ],
        [0.45408893, 0.28121132, 0.28604797, ..., 0.3837846 ,
         0.52248466, 0.33711308],
        ...,
        [0.6132575 , 0.7066365 , 0.77730024, ..., 0.61606586,
         0.2589899 , 0.74349356],
        [0.49634   , 0.688018  , 0.6851171 , ..., 0.7145026 ,
         0.30040693, 0.711934  ],
        [0.38210297, 0.25856316, 0.30580595, ..., 0.375045  ,
         0.39922574, 0.72658205]],

       [[0.2635966 , 0.72523916, 0.5818799 , ..., 0.68363416,
         0.27500525, 0.26784706],
        [0.74839205, 0.68505275, 0.44021732, ..., 0.26671767,
         0.71851015, 0.33116615],
        [0.26986226, 0.66971135, 0.75319296, ..., 0.46202248,
         0.3312493 , 0.2585347 ],
        ...,
        [0.2670113 , 0.26283354, 0.2643863 , ..., 0.2764622 ,
         0.54726934, 0.70552015],
        [0.3

In [17]:
PQ_code           # the generated PQ codes

array([[ 29, 212, 214, ..., 160,  46,  31],
       [ 80,  62,   6, ...,  53,  62, 187],
       [ 33,  82, 157, ..., 222, 207, 205],
       ...,
       [ 81, 173,  98, ..., 188, 232,  31],
       [116, 254, 188, ..., 181, 115,  82],
       [119, 231, 185, ...,  34,  40, 210]], dtype=uint8)

In [18]:
distance_table     # the generated distance table, transposed to shape (k, M)

array([[2.092305 , 2.4898663, 2.9422212, ..., 1.629852 , 2.654244 ,
        1.9011606],
       [2.088503 , 1.5805417, 2.0801227, ..., 1.937171 , 3.4474962,
        2.0638335],
       [2.9624765, 2.0734308, 1.4903691, ..., 1.8205851, 2.9466026,
        2.7976434],
       ...,
       [2.763594 , 2.2813995, 2.302858 , ..., 2.311877 , 2.6196272,
        1.862288 ],
       [2.5436997, 1.899764 , 2.10331  , ..., 1.3405595, 2.2057638,
        2.4736505],
       [1.9196714, 2.3986058, 2.3290567, ..., 2.6416855, 2.132494 ,
        1.0776821]], dtype=float32)

In [19]:
distances          # the distances between the database vectors and the query vector.

array([16.849852, 15.688269, 16.844519, ..., 16.050737, 17.386984,
       15.984062], dtype=float32)