
@ChristianAlexander
Created August 6, 2023 22:05
An Elixir Livebook notebook showing an implementation of the Random Cut Forest algorithm for anomaly detection.

Random Cut Forest

Mix.install([
  {:nx, "~> 0.5.3"},
  {:ex_zipper, "~> 0.1.3"},
  {:kino_vega_lite, "~> 0.1.9"}
])

Intro

Anomaly detection is a complicated and well-researched area of computer science.

In 2016, the paper Robust Random Cut Forest Based Anomaly Detection on Streams was published, serving as the inspiration for this notebook.

S. Guha, N. Mishra, G. Roy, and O. Schrijvers, "Robust random cut forest based anomaly detection on streams," in Proceedings of the 33rd International Conference on Machine Learning, New York, NY, 2016, pp. 2712–2721.

The Random Cut Forest algorithm is an unsupervised anomaly detection algorithm, most notably offered in Amazon SageMaker (docs). It aims to improve upon the Isolation Forest algorithm by measuring collusive displacement rather than just the depth of a leaf node.

A point's collusive displacement is a measure of how many leaves in the tree would be displaced as a result of its insertion or removal.

Each tree in the ensemble (the forest) produces its own collusive displacement value, and the arithmetic mean across all trees is used as the result.
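To make the definition concrete, here is a toy calculation (not part of the notebook's implementation, using invented subtree counts): walking from a leaf up to the root, each ancestor contributes the ratio of its sibling subtree's leaf count (the displaced leaves) to the count of the subtree being removed, and codisp is the maximum of those ratios.

```elixir
# Toy codisp calculation with invented counts. At each ancestor level,
# {num_deleted, displacement} is the size of the subtree being removed and
# the size of its sibling subtree. Codisp is the maximum ratio on the path.
ratios =
  [
    {1, 7},
    {8, 4},
    {12, 20}
  ]
  |> Enum.map(fn {num_deleted, displacement} -> displacement / num_deleted end)

codisp = Enum.max(ratios)
# codisp == 7.0: the isolated leaf's immediate sibling dominates,
# which is the typical signature of an anomalous point
```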

The goal of this notebook is to implement a proof-of-concept random cut forest to explore the concepts and gain a deeper understanding of the technique.

It is written entirely in Elixir, allowing it to be invoked without NIFs. However, a production-grade solution would likely involve calling a native implementation.

Naive Tree Implementation

What is the code doing?

The Tree module provides two key functions that make it work as a rudimentary random cut tree.

  • insert_point: inserts a point into the tree, making it referenceable by an index value
  • codisp: measures the collusive displacement of a point previously inserted into the tree, identified by its index value

This implementation does not include forget_point, which would remove a point from the tree. Without forget_point, stream maintenance methods described in the paper are not possible. To remove a specific point from the tree, a new tree must be constructed that omits the forgotten point.

On zippers

Data structures in Elixir are immutable, which makes tree construction and maintenance tricky. Zippers emulate mutable trees by "opening" a path down to a focused node and "closing" it back up, producing an updated tree after every operation. Because they copy the data along that path on each edit, they are not as efficient as the mutable structures available in imperative languages.
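As a warm-up (a standalone sketch, separate from the Tree module below), here is a minimal zipper over nested lists using the same ExZipper calls the implementation relies on; the final value assumes root/1 returns a zipper focused on the root, as the Tree code below also assumes.

```elixir
alias ExZipper.Zipper, as: Zipper

# A zipper over nested lists: branch?, children, and make_node describe the
# tree shape, just as the Tree module does for its Branch/Leaf structs.
zipper =
  Zipper.zipper(
    &is_list/1,                          # a node is a branch if it is a list
    & &1,                                # a branch's children are its elements
    fn _node, children -> children end,  # rebuild a branch from new children
    [1, [2, 3]]
  )

updated =
  zipper
  |> Zipper.down()            # focus on 1
  |> Zipper.right()           # focus on [2, 3]
  |> Zipper.down()            # focus on 2
  |> Zipper.edit(&(&1 * 10))  # replace the focus with 20
  |> Zipper.root()            # zip back up to the root
  |> Zipper.node()

# updated is [1, [20, 3]]; the original tree value is untouched
```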

alias ExZipper.Zipper, as: Zipper

defmodule Leaf do
  defstruct [:point, :index, count: 1]
end

defmodule Branch do
  defstruct [
    :cut_dimension,
    :cut_value,
    :count,
    :bounding_box,
    left_child: nil,
    right_child: nil
  ]
end

defmodule Tree do
  defstruct [:root, dimensionality: nil, indexes: MapSet.new()]

  def new do
    root =
      Zipper.zipper(
        fn
          %Branch{} -> true
          _ -> false
        end,
        fn %Branch{left_child: l, right_child: r} ->
          Enum.filter([l, r], &Function.identity/1)
        end,
        fn
          %Branch{} = b, [l, r] ->
            %{b | left_child: l, right_child: r}

          %Branch{} = b, [l] ->
            %{b | left_child: l}

          n, _ ->
            n
        end,
        nil
      )

    %__MODULE__{root: root}
  end

  @doc """
  Performs the collusive displacement calculation
  """
  def codisp(tree, index) do
    if not MapSet.member?(tree.indexes, index) do
      {:error, :index_not_in_tree}
    else
      leaf_zipper = find(tree.root, &(Map.get(&1, :index) == index))
      codisp_upward(leaf_zipper)
    end
  end

  defp codisp_upward(zipper) do
    if root?(zipper) do
      0.0
    else
      [sibling] = siblings(zipper)
      n = Zipper.node(zipper)
      num_deleted = n.count
      displacement = sibling.count
      result = displacement / num_deleted

      max(result, codisp_upward(Zipper.up(zipper)))
    end
  end

  @doc """
  Inserts a point into the tree, accessible in the future by its index value.
  """
  def insert_point(tree, point, index) do
    point = Nx.flatten(point)

    if not is_nil(tree.dimensionality) and Nx.size(point) != tree.dimensionality do
      raise "New points must have dimensions matching previously-seen points"
    end

    cond do
      MapSet.member?(tree.indexes, index) ->
        root =
          tree.root
          |> find(&(Map.get(&1, :index) == index))
          |> Zipper.edit(fn n -> %{n | count: n.count + 1} end)
          |> map_up(fn node ->
            %{node | count: node.count + 1}
          end)
          |> Zipper.root()

        %{tree | root: root}

      # TODO: Find existing leaf node by similarity search

      is_nil(Zipper.node(tree.root)) ->
        leaf =
          Zipper.make_node(
            tree.root,
            %Leaf{point: point, index: index},
            nil
          )

        root = Zipper.replace(tree.root, leaf)

        %{
          tree
          | dimensionality: Nx.size(point),
            root: root,
            indexes: MapSet.put(tree.indexes, index)
        }

      true ->
        zipper =
          do_insertion(tree.root, point, index)
          |> map_up(fn node ->
            %{
              node
              | count: node.count + 1,
                bounding_box: extend_bounding_box(bounding_box(node), point)
            }
          end)

        %{tree | root: Zipper.root(zipper), indexes: MapSet.put(tree.indexes, index)}
    end
  end

  defp siblings(zipper) do
    Zipper.lefts(zipper) ++ Zipper.rights(zipper)
  end

  defp root?(zipper), do: is_nil(zipper.crumbs)

  defp find(zipper, predicate) do
    # Finds a node that matches the predicate.
    # Returns a zipper focused on the found node.
    cond do
      Zipper.end?(zipper) ->
        nil

      predicate.(Zipper.node(zipper)) ->
        zipper

      true ->
        Zipper.next(zipper)
        |> find(predicate)
    end
  end

  defp map_up(zipper, transform) do
    # Performs the provided transformation on all nodes above the provided zipper focus.

    case Zipper.up(zipper) do
      {:error, _} ->
        zipper

      zipper ->
        zipper
        |> Zipper.edit(transform)
        |> map_up(transform)
    end
  end

  defp do_insertion(zipper, point, index) do
    n = Zipper.node(zipper)
    bbox = bounding_box(n)

    {cut_dimension, cut, new_box} = get_cut(point, bbox)

    cond do
      cut <= Nx.to_number(bbox[0][cut_dimension]) ->
        leaf = %Leaf{point: point, index: index}

        branch = %Branch{
          cut_dimension: cut_dimension,
          cut_value: cut,
          count: n.count + leaf.count,
          left_child: leaf,
          right_child: n,
          bounding_box: new_box
        }

        Zipper.replace(zipper, branch)

      cut >= Nx.to_number(bbox[-1][cut_dimension]) ->
        leaf = %Leaf{point: point, index: index}

        branch = %Branch{
          cut_dimension: cut_dimension,
          cut_value: cut,
          count: n.count + leaf.count,
          left_child: n,
          right_child: leaf,
          bounding_box: new_box
        }

        Zipper.replace(zipper, branch)

      Nx.to_number(point[n.cut_dimension]) <= n.cut_value ->
        do_insertion(zipper |> Zipper.down() |> Zipper.leftmost(), point, index)

      true ->
        do_insertion(zipper |> Zipper.down() |> Zipper.rightmost(), point, index)
    end
  end

  defp get_cut(point, box) do
    # Gets a cut in a random dimension, weighted by each dimension's span.
    # Returns the selected dimension, cut value, and updated bounding box that includes the provided point.

    box = extend_bounding_box(box, point)

    box_span = Nx.subtract(box[-1], box[0])
    box_range = Nx.sum(box_span) |> Nx.to_number()
    span_sum = Nx.cumulative_sum(box_span) |> Nx.to_list()
    random = :rand.uniform_real() * box_range

    {cumulative_sum, cut_dimension} =
      span_sum
      |> Enum.with_index()
      |> Enum.find(&(elem(&1, 0) >= random))

    cut = Nx.to_number(box[0][cut_dimension]) + cumulative_sum - random

    {cut_dimension, cut, box}
  end

  defp extend_bounding_box(box, point) do
    Nx.concatenate([
      Nx.min(box[0], point),
      Nx.max(box[-1], point)
    ])
    |> Nx.reshape({2, :auto})
  end

  defp bounding_box(n) do
    case n do
      %Leaf{point: p} -> Nx.reshape(Nx.concatenate([p, p]), {2, :auto})
      %Branch{bounding_box: b} -> b
    end
  end
end

Sine Wave Anomaly Exercise

This simulation mirrors section 5.1 of the paper, where an anomaly is inserted into a sine wave.

Beyond the simple sine wave, the simulation optionally adds random jitter to each point.

# Simulation Variables
wave_period = 80
wave_magnitude = 40
wave_offset = 50
wave_randomness = 3
total_length = 1000

anomaly_start = 650
anomaly_value = 45
anomaly_length = 20

data =
  1..total_length
  |> Enum.map(fn i ->
    :math.sin(2 * :math.pi() / wave_period * i) *
      wave_magnitude + wave_offset + (wave_randomness * :rand.uniform() - 0.5 * wave_randomness)
  end)

constant_anomaly = Stream.cycle([anomaly_value]) |> Enum.take(anomaly_length)

data =
  Enum.concat([
    Enum.slice(data, 0..anomaly_start),
    constant_anomaly,
    Enum.slice(data, (anomaly_start + anomaly_length)..-1//1)
  ])

alias VegaLite, as: Vl

Vl.new(title: [text: "Sine Wave"], width: 800)
|> Vl.data_from_values(time: 1..Enum.count(data), data: data)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "time", type: :quantitative)
|> Vl.encode_field(:y, "data", type: :quantitative)

Just like in the paper, shingling is used so that each point represents a contiguous window of datapoints rather than a single value. The window size is controlled by shingle_length.
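For illustration (a standalone sketch, separate from the hyperparameter code below), shingling a short series with a window of 3 produces overlapping windows using the same drop/zip construction the notebook uses:

```elixir
# Shingling turns a scalar series into overlapping windows so each point
# carries a little local history.
series = [10, 11, 12, 13, 14]
shingle_length = 3

shingles =
  0..(shingle_length - 1)
  |> Enum.map(&Enum.drop(series, &1))
  |> Enum.zip()
  |> Enum.map(&Tuple.to_list/1)

# shingles == [[10, 11, 12], [11, 12, 13], [12, 13, 14]]
```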

The forest is trained on a region of the data, described by the range value training_region, with samples_per_tree controlling how many shingled values are inserted into (learned by) each tree. Values are selected with Enum.take_random/2, which is based on reservoir sampling.

tree_count controls how many trees are in the forest. A greater number of trees typically results in a less sensitive forest, at a cost of execution speed and memory.
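The forest's score for a point is simply the arithmetic mean of the per-tree codisp values; a one-line sketch with made-up numbers:

```elixir
# Invented per-tree codisp values for one shingled point; the forest's
# anomaly score is their arithmetic mean.
per_tree_codisp = [4.0, 6.0, 5.0, 9.0]
score = Enum.sum(per_tree_codisp) / length(per_tree_codisp)
# score == 6.0
```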

# Hyperparameters
shingle_length = 4
training_region = 0..500
samples_per_tree = 256
tree_count = 50

shingled_data =
  0..(shingle_length - 1)
  |> Stream.map(&Stream.drop(data, &1))
  |> Stream.zip()
  |> Enum.map(fn shingle ->
    Tuple.to_list(shingle)
    |> Nx.tensor()
  end)

training_data = Enum.slice(shingled_data, training_region)

make_tree = fn ->
  Enum.with_index(training_data)
  |> Enum.take_random(samples_per_tree)
  |> Enum.reduce(Tree.new(), fn {point, index}, t ->
    Tree.insert_point(t, point, index)
  end)
end

forest = Enum.map(1..tree_count, fn _ -> make_tree.() end)

Task.async_stream/3 is used to perform the displacement calculations concurrently; the work is typically scheduled across all CPU cores.

Points do not need to be removed, since the forest and the trees are immutable. codisp is evaluated on a unique copy of the tree.

This is then collected into a tensor to perform a mean calculation—our final anomaly score for the shingled datapoint.

The calculation is performed for each datapoint.

results =
  Enum.map(shingled_data, fn point ->
    Task.async_stream(forest, fn tree ->
      Tree.insert_point(tree, point, "test")
      |> Tree.codisp("test")
    end)
    |> Enum.map(&elem(&1, 1))
    |> Nx.tensor()
    |> Nx.mean()
    |> Nx.to_number()
  end)

The beginning and end of the anomaly are found, as indicated by two large spikes in the outlier score.

graphable_results =
  Enum.zip(results, data)
  |> Enum.with_index()
  |> Enum.map(fn {{outlier_score, data}, index} ->
    [outlier_score: outlier_score, data: data, index: index]
  end)

Vl.new(title: [text: "Anomaly Detection"], width: 800)
|> Vl.data_from_values(graphable_results)
|> Vl.mark(:line)
|> Vl.transform(fold: ["data", "outlier_score"])
|> Vl.encode_field(:x, "index", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "key")

Finally, a distribution of outlier scores is shown. This can be helpful for tuning hyperparameters and alerting thresholds.

Vl.new(title: [text: "Distribution of Outlier Scores"], width: 800)
|> Vl.data_from_values(outlier_scores: results)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "outlier_scores", type: :quantitative, bin: [step: 0.5])
|> Vl.encode(:y, aggregate: :count)

Summary

This notebook has explored a partial implementation of the approach described in the RCF paper. While it does not share the high performance of implementations in languages with mutable data structures, it has helped me build an intuition for how features and hyperparameters contribute to anomaly detection.


Future enhancements to consider building:

  • Zippers are only required for writes. Update the codisp calculation to model a hypothetical insert by operating on the tree itself (not involving the zipper).
  • Trees don't need to be built incrementally unless it's a streaming scenario. Add a tree-building method that subdivides the dataset, instead of finding a place where the point would be added.
