A navigable small world graph. Image from NSWG.
A small world graph, or small world network, is a mathematical graph with two characteristic properties. First, most nodes are not neighbors of each other, but two neighbors of the same node are likely to be neighbors as well. Second, any node can be reached from any other node in very few edge transitions.
The idea is to use such a structure to support fast approximate nearest neighbor queries.
In a navigable small world graph (NSWG), each node or vertex is associated with a vector in $\mathbb{R}^d$. During search, we compute the distance from the query to all neighbors of the current node and move to the closest one. We repeat this step until we reach a local minimum, i.e. a node that is closer to the query than all of its neighbors. In this post we discuss how to compute the $k$ closest neighbors and explore the method using datasets from the UCR time series archive. The code can be found on GitHub.
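The greedy routing step described above can be sketched as follows. This is a minimal illustration, not code from the repository: it assumes an adjacency list `Vec<Vec<usize>>` and points stored as `Vec<Vec<f32>>`, and the names `greedy_search` and `sq_dist` are hypothetical.

```rust
/// Squared Euclidean distance between two vectors of equal length.
fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Greedy routing: starting at `entry`, repeatedly move to the neighbor
/// closest to `query` until no neighbor is closer than the current node.
fn greedy_search(points: &[Vec<f32>], edges: &[Vec<usize>], query: &[f32], entry: usize) -> usize {
    let mut current = entry;
    let mut current_dist = sq_dist(&points[current], query);
    loop {
        // find the closest neighbor of the current node
        let mut next = current;
        let mut next_dist = current_dist;
        for &neighbor in &edges[current] {
            let d = sq_dist(&points[neighbor], query);
            if d < next_dist {
                next = neighbor;
                next_dist = d;
            }
        }
        if next == current {
            // no neighbor is closer: we reached a local minimum
            return current;
        }
        current = next;
        current_dist = next_dist;
    }
}
```

For example, on the path graph 0-1-2-3-4 over the 1-d points 0.0 to 4.0, a query at 3.2 entered at node 0 walks through nodes 0, 1, 2, 3 and stops at 3, because node 4 is farther from the query. On less regular graphs the walk can stop at a local minimum that is not the true nearest neighbor, which is why the k-NN search restarts from several random entry points.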
Modelling the Graph and Search Results
The edges of the graph are stored in a hash map from each vertex to the list of its neighboring vertices. The struct also holds the vectors associated with the vertices, flattened into a single array.
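```rust
use std::collections::HashMap;

/// A navigable small world graph defined
/// by a sparse edge set and a flat vector set, one for each vertex.
pub struct NavigableSmallWorldGraph {
    /// adjacency list: vertex id -> ids of its neighboring vertices
    pub edges: HashMap<usize, Vec<usize>>,
    /// all vectors concatenated; vertex `i` occupies `instances[i * dim .. (i + 1) * dim]`
    pub instances: Vec<f32>,
    /// dimensionality of every vector
    pub dim: usize,
}
```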
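To make the flat layout concrete, here is a minimal sketch of a constructor that flattens equal-length vectors into `instances`, an accessor for the slice belonging to one vertex, and a helper that stores an edge in both adjacency lists. `from_vectors`, `vector` and `connect` are hypothetical helpers for illustration, not functions from the repository; the struct is repeated so the block compiles on its own.

```rust
use std::collections::HashMap;

pub struct NavigableSmallWorldGraph {
    pub edges: HashMap<usize, Vec<usize>>,
    pub instances: Vec<f32>,
    pub dim: usize,
}

impl NavigableSmallWorldGraph {
    /// Hypothetical constructor: concatenates all vectors into one flat
    /// buffer and starts without any edges.
    pub fn from_vectors(vectors: &[Vec<f32>]) -> Self {
        let dim = vectors.first().map_or(0, |v| v.len());
        assert!(vectors.iter().all(|v| v.len() == dim), "all vectors need the same length");
        let instances = vectors.iter().flatten().copied().collect();
        NavigableSmallWorldGraph { edges: HashMap::new(), instances, dim }
    }

    /// Hypothetical accessor: the vector of vertex `node`.
    pub fn vector(&self, node: usize) -> &[f32] {
        &self.instances[node * self.dim..(node + 1) * self.dim]
    }

    /// Hypothetical helper: records the edge a-b in both adjacency lists.
    pub fn connect(&mut self, a: usize, b: usize) {
        self.edges.entry(a).or_default().push(b);
        self.edges.entry(b).or_default().push(a);
    }
}
```

Storing all vectors in one `Vec<f32>` keeps them contiguous in memory, so a distance computation only needs a slice into that buffer instead of following a pointer per vertex.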
A search result is simply implemented as a struct with a node id as well as a distance
to the query.
```rust
#[derive(Clone, Copy, Debug)]
pub struct SearchResult {
    pub node: usize,
    pub distance: f32,
}
```
Furthermore, we order search results by their distance to the query. To manage the set of candidates during k-NN search, we implement a result store backed by a `BTreeSet`, Rust's ordered set. In this way, extracting the top neighbors is fast.
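The ordering itself is not shown in the post. A minimal sketch, assuming results are compared by distance via `f32::total_cmp` (Rust 1.62+) with the node id as a tie-breaker: without the tie-breaker, two different nodes at exactly the same distance compare as equal, and the `BTreeSet` silently keeps only one of them. The `new` and `none` constructors are likewise assumptions; here `none` is a sentinel with infinite distance, so that the early-stopping test in the search never fires while fewer than k results are known.

```rust
use std::cmp::Ordering;

#[derive(Clone, Copy, Debug)]
pub struct SearchResult {
    pub node: usize,
    pub distance: f32,
}

impl SearchResult {
    pub fn new(node: usize, distance: f32) -> SearchResult {
        SearchResult { node, distance }
    }

    /// Assumed sentinel for "no result": infinitely far away.
    pub fn none() -> SearchResult {
        SearchResult { node: usize::MAX, distance: f32::INFINITY }
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // primary key: distance (total order, so NaN cannot corrupt the tree);
        // secondary key: node id, so equal distances do not collapse entries
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.node.cmp(&other.node))
    }
}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}
```

With this ordering, iterating a `BTreeSet<SearchResult>` yields the closest result first, and results for different nodes at the same distance are all kept.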
```rust
use std::collections::BTreeSet;

/// Search results kept sorted by ascending distance to the query.
pub struct ResultStore {
    ordered: BTreeSet<SearchResult>,
}

impl ResultStore {
    pub fn new() -> ResultStore {
        ResultStore { ordered: BTreeSet::new() }
    }

    pub fn len(&self) -> usize {
        self.ordered.len()
    }

    /// The (up to) `k` closest results, closest first.
    pub fn take(&self, k: usize) -> Vec<SearchResult> {
        self.ordered.iter().take(k).cloned().collect()
    }

    /// The `at`-th closest result (1-based), or `SearchResult::none()`
    /// if the store holds fewer than `at` results.
    pub fn best(&self, at: usize) -> SearchResult {
        at.checked_sub(1)
            .and_then(|i| self.ordered.iter().nth(i))
            .cloned()
            .unwrap_or_else(SearchResult::none)
    }

    pub fn insert(&mut self, result: SearchResult) {
        self.ordered.insert(result);
    }

    pub fn remove(&mut self, result: &SearchResult) {
        self.ordered.remove(result);
    }
}
```
A `BTreeSet` keeps its elements sorted on insertion and on removal, and both operations take logarithmic time. Iterating over the elements therefore yields the candidates in ascending order of distance.
KNN-Search and Graph Construction
First we define the distances we support. For this post, I implement the Euclidean distance between a query vector and a node in the graph, as well as the dynamic time warping (DTW) distance, which treats the vectors as time series. If a warping band is passed to the distance function, we compute the DTW distance; otherwise we compute the Euclidean distance.
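For reference, here are the two distances written out for a query $x$ and a stored vector $y$, both of length $d$. This sketch assumes the standard Sakoe-Chiba constraint $|i - j| \le w$ for a band of width $w$; as in the code, the DTW cost accumulates squared differences and is not square-rooted.

$$
d_{\mathrm{euc}}(x, y) = \sqrt{\sum_{i=1}^{d} (x_i - y_i)^2}
$$

$$
D(i, j) = (x_i - y_j)^2 + \min\{D(i-1, j),\ D(i, j-1),\ D(i-1, j-1)\} \quad \text{for } 1 \le i, j \le d,\ |i - j| \le w
$$

with $D(0, 0) = 0$, $D(i, j) = \infty$ for every other cell, and $d_{\mathrm{dtw}}(x, y) = D(d, d)$. With $w = 0$ only the diagonal is admissible, so $d_{\mathrm{dtw}} = d_{\mathrm{euc}}^2$. The two measures therefore live on different scales, which does not matter as long as a single search uses only one of them.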
```rust
impl NavigableSmallWorldGraph {
    /// DTW with a Sakoe-Chiba band of width `w` if `align_band` is set,
    /// Euclidean distance otherwise.
    fn distance(&self, query: &[f32], node: usize, align_band: Option<usize>) -> f32 {
        match align_band {
            Some(w) => self.dtw(query, node, w),
            None => self.euclidean(query, node),
        }
    }

    fn euclidean(&self, query: &[f32], node: usize) -> f32 {
        assert_eq!(query.len(), self.dim);
        let y = &self.instances[node * self.dim..(node + 1) * self.dim];
        query
            .iter()
            .zip(y)
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f32>()
            .sqrt()
    }

    /// Accumulated squared cost of the best warping path (no final square root).
    fn dtw(&self, query: &[f32], node: usize, w: usize) -> f32 {
        assert_eq!(query.len(), self.dim);
        let y = &self.instances[node * self.dim..(node + 1) * self.dim];
        let n = query.len();
        let m = y.len();
        let cols = m + 1; // row stride of the (n + 1) x (m + 1) cost matrix
        // allocating here is the main bottleneck
        let mut dp = vec![f32::INFINITY; (n + 1) * cols];
        dp[0] = 0.0;
        for i in 1..=n {
            // only cells with |i - j| <= w lie inside the band
            for j in usize::max(i.saturating_sub(w), 1)..usize::min(i + w + 1, m + 1) {
                let d = query[i - 1] - y[j - 1];
                dp[i * cols + j] = d * d
                    + f32::min(
                        dp[(i - 1) * cols + j],
                        f32::min(dp[i * cols + (j - 1)], dp[(i - 1) * cols + (j - 1)]),
                    );
            }
        }
        dp[dp.len() - 1]
    }
}
```
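To check the banded recurrence in isolation, here is a minimal free-standing version that works on two slices instead of a graph node. The name `dtw_banded` is hypothetical; the sketch assumes the Sakoe-Chiba band $|i - j| \le w$ and, like the `dtw` method, returns the accumulated squared cost.

```rust
/// Banded dynamic time warping between `x` and `y` (Sakoe-Chiba band of width `w`).
/// Returns the accumulated squared cost without a final square root, or
/// infinity if the band is too narrow to connect the two end points.
fn dtw_banded(x: &[f32], y: &[f32], w: usize) -> f32 {
    let (n, m) = (x.len(), y.len());
    let cols = m + 1; // row stride of the (n + 1) x (m + 1) cost matrix
    let mut dp = vec![f32::INFINITY; (n + 1) * cols];
    dp[0] = 0.0;
    for i in 1..=n {
        let lo = usize::max(i.saturating_sub(w), 1);
        let hi = usize::min(i + w, m); // inclusive upper bound of the band
        for j in lo..=hi {
            let d = x[i - 1] - y[j - 1];
            // cheapest way to reach (i, j): from above, from the left, or diagonally
            let best_prev = dp[(i - 1) * cols + j]
                .min(dp[i * cols + j - 1])
                .min(dp[(i - 1) * cols + j - 1]);
            dp[i * cols + j] = d * d + best_prev;
        }
    }
    dp[n * cols + m]
}
```

With `w = 0` only the diagonal is admissible and the result equals the squared Euclidean distance; widening the band can only lower the cost. For `x = [0, 1, 2, 1, 0]` and its shifted copy `y = [0, 0, 1, 2, 1]`, the diagonal costs 4, while a band of width 1 lets the peaks align and leaves only the unavoidable mismatch of the last samples, a cost of 1.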
We can now turn to the search algorithm. During the search, we maintain an ordered candidate list, an ordered result list and an unordered visited set. Each search run starts at a random entry node that has not been visited yet, which we add as the first candidate. In each following step, we take the closest node from the candidate list and add all of its neighbors that we have not visited before as candidates. If the closest candidate is farther away than the k-th nearest neighbor found so far, we stop.
We restart the search $n$ times from different random entry points.
```rust
use rand::Rng; // rand 0.8 API: `gen_range(low..high)`
use std::collections::HashSet;

impl NavigableSmallWorldGraph {
    /// Number of vectors stored in the graph.
    pub fn n_instances(&self) -> usize {
        self.instances.len() / self.dim
    }

    /// Given a query find the k-nearest neighbors using the nswg graph
    /// * `query` The query vector
    /// * `n_searches` number of restarts
    /// * `k_neighbors` number of results
    /// * `align_band` Sakoe Chiba Band for Dynamic Time Warping, if not set we use Euclidean Distance
    ///
    /// Returns the results together with the number of distance computations.
    pub fn search(
        &self,
        query: &[f32],
        n_searches: usize,
        k_neighbors: usize,
        align_band: Option<usize>,
    ) -> (Vec<SearchResult>, usize) {
        let mut results = ResultStore::new();
        let mut visited: HashSet<usize> = HashSet::new();
        let mut rng = rand::thread_rng();
        let mut n_steps = 0;
        for _attempt in 0..n_searches {
            // stop if we visited all of the nodes
            if self.n_instances() == visited.len() {
                break;
            }
            // every restart gets its own candidate list and best-so-far list
            let mut candidates = ResultStore::new();
            let mut bsf = ResultStore::new();
            // sample an entry point we have not visited before
            let mut entry_point = rng.gen_range(0..self.n_instances());
            while visited.contains(&entry_point) {
                entry_point = rng.gen_range(0..self.n_instances());
            }
            // mark the entry point as visited and insert it as the first candidate
            visited.insert(entry_point);
            n_steps += 1;
            let distance = self.distance(query, entry_point, align_band);
            candidates.insert(SearchResult::new(entry_point, distance));
            bsf.insert(SearchResult::new(entry_point, distance));
            while candidates.len() > 0 {
                // abandon the run if the closest candidate is worse than the k-th best so far
                let candidate = candidates.best(1);
                let kth = f32::min(
                    results.best(k_neighbors).distance,
                    bsf.best(k_neighbors).distance,
                );
                candidates.remove(&candidate);
                if candidate.distance > kth {
                    break;
                }
                // expand all neighbors we have not visited yet
                if let Some(neighbors) = self.edges.get(&candidate.node) {
                    for &edge in neighbors {
                        // `insert` returns true only if `edge` was not visited before
                        if visited.insert(edge) {
                            n_steps += 1;
                            let distance = self.distance(query, edge, align_band);
                            candidates.insert(SearchResult::new(edge, distance));
                            bsf.insert(SearchResult::new(edge, distance));
                        }
                    }
                }
            }
            for res in bsf.take(k_neighbors) {
                results.insert(res);
            }
        }
        (results.take(k_neighbors), n_steps)
    }
}
```
Testing the Code on UCR Time Series
In order to test the code, we run it against multiple datasets from the UCR time series archive.
Our accuracy is similar to the accuracy reported in the UCR archive, while the number of distance computations is much lower.