Last active
May 29, 2025 06:50
-
-
Save edsko/f6fc114d882d88b8ae01efe3465fbed6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Test.Util.LinearRegression ( | |
fitLine | |
, tests | |
) where | |
import Test.Tasty | |
import Test.Tasty.QuickCheck | |
import Test.Tensor.TestValue (TestValue) | |
{------------------------------------------------------------------------------- | |
Definition | |
-------------------------------------------------------------------------------} | |
-- | Fit a straight line through a list of points by simple linear regression
-- (ordinary least squares).
--
-- The result is the pair @(offset, slope)@, i.e. @(a, b)@ for the line
-- @y = a + b * x@. Both are computed directly from the closed-form
-- "expanded" formulas:
--
-- > d      = n * sum (x*x) - sum x * sum x
-- > offset = (sum y * sum (x*x) - sum x * sum (x*y)) / d
-- > slope  = (n * sum (x*y) - sum x * sum y) / d
--
-- The shared denominator @d@ is zero (in exact arithmetic) when there are
-- fewer than two distinct @x@ values. No check is made for this, so the result
-- is then whatever division by zero produces for @a@ (e.g. NaN/Infinity for
-- 'Double', an exception for 'Rational').
--
-- See <https://en.wikipedia.org/wiki/Simple_linear_regression>, especially
-- <https://en.wikipedia.org/wiki/Simple_linear_regression#Expanded_formulas>.
fitLine :: forall a. Fractional a => [(a, a)] -> (a, a)
fitLine points = (offset, slope)
  where
    -- Separate the coordinates once; every sum below works on these lists.
    xs, ys :: [a]
    (xs, ys) = unzip points

    count, sumX, sumY, sumXX, sumXY :: a
    count = fromIntegral (length points)
    sumX  = sum xs
    sumY  = sum ys
    sumXX = sum (zipWith (*) xs xs)
    sumXY = sum (zipWith (*) xs ys)

    -- Common denominator of both estimators.
    denominator :: a
    denominator = count * sumXX - sumX * sumX

    offset, slope :: a
    offset = (sumY * sumXX - sumX * sumXY) / denominator
    slope  = (count * sumXY - sumX * sumY) / denominator
{------------------------------------------------------------------------------- | |
Tests | |
TODO: | |
* Test with non-uniform X values | |
* Test with noise | |
  That said, we're not really concerned with the numerical accuracy of the
  algorithm, only with (1) whether it's implemented correctly and (2) whether
  it's the /right/ algorithm. The tests as they are now seem to confirm both.
-------------------------------------------------------------------------------} | |
-- | Test tree collecting the QuickCheck properties for 'fitLine'.
tests :: TestTree
tests =
    testGroup "Test.Util.LinearRegression" properties
  where
    properties :: [TestTree]
    properties =
      [ testProperty "constant" prop_constant
      , testProperty "diagonal" prop_diagonal
      , testProperty "general"  prop_general
      ]
-- | Horizontal line @y = c@: the fit must yield offset @c@ and slope @0@.
--
-- Samples are taken at @x = 0, 1, .., n - 1@.
prop_constant :: NumPoints -> TestValue -> Property
prop_constant (NumPoints numPoints) c =
    fitLine samples === (c, 0)
  where
    samples :: [(TestValue, TestValue)]
    samples = map (\i -> (fromIntegral i, c)) [0 .. numPoints - 1]
-- | Identity line @y = x@: the fit must yield offset @0@ and slope @1@.
--
-- Samples are taken at @x = 0, 1, .., n - 1@.
prop_diagonal :: NumPoints -> Property
prop_diagonal (NumPoints numPoints) =
    fitLine samples === (0, 1)
  where
    samples :: [(TestValue, TestValue)]
    samples = [ (x, x) | x <- map fromIntegral [0 .. numPoints - 1] ]
-- | Arbitrary line @y = a + b * x@ without noise: the fit must yield offset
-- @a@ and slope @b@.
--
-- Samples are taken at @x = 0, 1, .., n - 1@.
prop_general :: NumPoints -> TestValue -> TestValue -> Property
prop_general (NumPoints numPoints) a b =
    fitLine samples === (a, b)
  where
    samples :: [(TestValue, TestValue)]
    samples = zip xs (map line xs)

    xs :: [TestValue]
    xs = map fromIntegral [0 .. numPoints - 1]

    line :: TestValue -> TestValue
    line x = a + b * x
{------------------------------------------------------------------------------- | |
Auxiliary | |
-------------------------------------------------------------------------------} | |
-- | Number of sample points used by the properties
--
-- A line is only determined by at least two points, so every generated value
-- and every shrink candidate is at least 2.
newtype NumPoints = NumPoints Int
  deriving stock (Show)

instance Arbitrary NumPoints where
  -- Offset a non-negative Int by 2 to guarantee the lower bound.
  arbitrary = fmap (\(NonNegative extra) -> NumPoints (extra + 2)) arbitrary

  -- Reuse Int shrinking, discarding candidates that would break the bound.
  shrink (NumPoints n) = [ NumPoints n' | n' <- shrink n, n' >= 2 ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment