Skip to content

Instantly share code, notes, and snippets.

@edsko
Last active May 29, 2025 06:50
Show Gist options
  • Save edsko/f6fc114d882d88b8ae01efe3465fbed6 to your computer and use it in GitHub Desktop.
Save edsko/f6fc114d882d88b8ae01efe3465fbed6 to your computer and use it in GitHub Desktop.
module Test.Util.LinearRegression (
fitLine
, tests
) where
import Test.Tasty
import Test.Tasty.QuickCheck
import Test.Tensor.TestValue (TestValue)
{-------------------------------------------------------------------------------
Definition
-------------------------------------------------------------------------------}
-- | Fit a straight line through a set of points using simple linear
-- regression (ordinary least squares).
--
-- The result is @(offset, slope)@, i.e. the fitted line is
-- @y = offset + slope * x@.
--
-- The closed-form "expanded" formulas are used; see
-- <https://en.wikipedia.org/wiki/Simple_linear_regression>, in particular
-- <https://en.wikipedia.org/wiki/Simple_linear_regression#Expanded_formulas>.
--
-- Both components are divided by @n * sum (x * x) - (sum x) ^ 2@, which is
-- (mathematically) zero when there are fewer than two points or when all
-- x-coordinates coincide. No check is made for this, so such inputs produce
-- whatever division by zero yields for the chosen type @a@.
fitLine :: forall a. Fractional a => [(a, a)] -> (a, a)
fitLine points = (offset, slope)
  where
    xs, ys :: [a]
    xs = map fst points
    ys = map snd points

    count, sumX, sumY, sumXX, sumXY :: a
    count = fromIntegral (length points)
    sumX  = sum xs
    sumY  = sum ys
    sumXX = sum (zipWith (*) xs xs)
    sumXY = sum (zipWith (*) xs ys)

    -- Denominator shared by both estimators.
    divisor :: a
    divisor = count * sumXX - sumX * sumX

    offset, slope :: a
    offset = (sumY * sumXX - sumX * sumXY) / divisor
    slope  = (count * sumXY - sumX * sumY) / divisor
{-------------------------------------------------------------------------------
Tests
TODO:
* Test with non-uniform X values
* Test with noise
That said, we're not really concerned about the numerical accuracy of the
algorithm, but merely with (1) whether it's implemented correctly and (2)
whether it's the /right/ algorithm. The tests as they are now seem to confirm
both.
-------------------------------------------------------------------------------}
-- | Property tests for 'fitLine'.
--
-- Each property fits noise-free points placed at the x-coordinates
-- @0 .. n - 1@ and checks that the generating line is recovered.
tests :: TestTree
tests =
    testGroup groupName
      [ testProperty "constant" prop_constant
      , testProperty "diagonal" prop_diagonal
      , testProperty "general"  prop_general
      ]
  where
    groupName :: TestName
    groupName = "Test.Util.LinearRegression"
-- | Points on the horizontal line @y = c@ must be fitted with offset @c@
-- and slope @0@.
prop_constant :: NumPoints -> TestValue -> Property
prop_constant (NumPoints n) c = fitLine points === (c, 0)
  where
    points :: [(TestValue, TestValue)]
    points = map (\x -> (x, c)) xs

    -- Evenly spaced x-coordinates 0, 1, .., n - 1.
    xs :: [TestValue]
    xs = map fromIntegral [0 .. n - 1]
-- | Points on the identity line @y = x@ must be fitted with offset @0@ and
-- slope @1@.
prop_diagonal :: NumPoints -> Property
prop_diagonal (NumPoints n) = fitLine points === (0, 1)
  where
    -- The explicit type pins the literals above to 'TestValue'.
    points :: [(TestValue, TestValue)]
    points = map (\x -> (x, x)) xs

    -- Evenly spaced x-coordinates 0, 1, .., n - 1.
    xs :: [TestValue]
    xs = map fromIntegral [0 .. n - 1]
-- | Noise-free points on an arbitrary line @y = a + b * x@ must be fitted
-- with exactly offset @a@ and slope @b@.
prop_general :: NumPoints -> TestValue -> TestValue -> Property
prop_general (NumPoints n) a b = fitLine points === (a, b)
  where
    points :: [(TestValue, TestValue)]
    points = zip xs (map line xs)

    -- The line the points are sampled from.
    line :: TestValue -> TestValue
    line x = a + b * x

    -- Evenly spaced x-coordinates 0, 1, .., n - 1.
    xs :: [TestValue]
    xs = map fromIntegral [0 .. n - 1]
{-------------------------------------------------------------------------------
Auxiliary
-------------------------------------------------------------------------------}
-- | How many points a property should generate.
--
-- A line can only be fitted through at least two points, so both the
-- generator and the shrinker keep the count at @2@ or above.
newtype NumPoints = NumPoints Int
  deriving stock (Show)

instance Arbitrary NumPoints where
  arbitrary = do
      extra <- getNonNegative <$> arbitrary
      pure $ NumPoints (extra + 2)

  -- Reuse the 'Int' shrinker, discarding candidates below the minimum.
  shrink (NumPoints n) = [NumPoints m | m <- shrink n, m >= 2]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment