/*

Examples (for the second tree drawn below):
3rd smallest: 15
6th smallest: 50 (7th smallest: 75)
4th smallest: 25
class Node {
  int val;
  Node left;
  Node right;
}

                            100
                          75
                        50
                      35
                    20
                  1
                -1
              -2
              
                                     100
                                
                                 50        150 
                            25       75        200
                        5     30
                      1  15


Node findKthSmallest(Node root, int k){
  
  
}

*/
import java.io.*;

import java.util.Map.*;
import java.util.*;

class MyCode {
  static Map<Integer, Integer> treeSize;
	public static void main (String[] args) {
		System.out.println("Hello Java");
    
    // create a BST
    Node root = new Node(100);
    root.left = new Node(50);
    root.right = new Node(150);
    root.left.left = new Node(25);
    root.left.right = new Node(75);
    root.left.left.left = new Node(5);
    root.left.left.right = new Node(30);
    root.left.left.left.left = new Node(1);
    root.left.left.left.right = new Node(15);
    root.right.right = new Node(200);
    
    treeSize = new HashMap<>();
    int k = 1;
    System.out.println("Kth smallest value: " + findKthSmallest(root, k).val);
    k = 8;
    System.out.println("Kth smallest value: " + findKthSmallest(root, k).val);
    k = 10;
    System.out.println("Kth smallest value: " + findKthSmallest(root, k).val);
    // System.out.println("Kth smallest value: " + findKthSmallest(root, k).val);
    // System.out.println("Kth smallest value: " + findKthSmallest(root, k).val);
    
	}

  public static Node findKthSmallest(Node root, int k){
    // find the size of the left tree
    if(root == null){ return root; }
    int lSize = 0;
    if(root.left != null && treeSize.containsKey(root.left.val)){
      lSize = treeSize.get(root.left.val);
    }else{
      lSize = getTreeSize(root.left); // O(n)
      if(root.left != null){
        treeSize.put(root.left.val, lSize);
      }
    }
    // if root size is the size that we are looking for then return root
    // otherwise discover left tree
    // or right tree
    if(lSize + 1 == k){
      return root;
    }else if(lSize + 1 > k){
      return findKthSmallest(root.left, k);
    } else{
      return findKthSmallest(root.right, k - lSize - 1);
    }
  }
  
  public static int getTreeSize(Node root){
    // base case
    if(root == null){
      return 0;
    }
    // recursion case
    return 1 + getTreeSize(root.left) + getTreeSize(root.right);
  }
  public static class Node {
    int val;
    Node left;
    Node right;
    public Node(int val){
      this.val = val;
    }
  }
}