Skip to content

Instantly share code, notes, and snippets.

@Mozk0
Last active December 22, 2015 04:38
Show Gist options
  • Save Mozk0/6418331 to your computer and use it in GitHub Desktop.
Save Mozk0/6418331 to your computer and use it in GitHub Desktop.
クロッシング問題をHaskell, D, C++でそれぞれ書いた。いずれもマージソートを行い、マージ中に転倒数を数えていくアルゴリズム。crossing.hs: 200ms（vectorを使って高速化）、crossing2.hs: 1200ms（listのみを使用）、crossing.d: 130ms、crossing.c++: 80ms
クロッシング問題をそれぞれHaskell, D, C++で書いた。全てマージソートでマージ中に転倒数を数えていくアルゴリズム。
crossing.hs : 200ms（vectorを使って高速化）
crossing2.hs : 1200ms（listのみを使用）
crossing.d : 130ms
crossing.c++ : 80ms
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
namespace {

// Merges the two adjacent sorted runs
//   a[i, i + chunk_size)  and  a[i + chunk_size, min(i + 2 * chunk_size, a.size()))
// in place, using `tmp` as scratch space, and returns the number of pairs
// (x from the left run, y from the right run) with x >= y.
// Precondition: i + chunk_size < a.size(), i.e. the right run is non-empty.
//
// Ties are counted as crossings. This matches the D and Haskell versions and
// replaces the previous behaviour, which looped forever printing "Oops" as soon
// as two equal values met.
size_t Merge(std::vector<size_t>& a, size_t i, size_t chunk_size, std::vector<size_t>& tmp) {
  std::vector<size_t>::iterator a1 = a.begin() + i;
  const size_t l1 = chunk_size;
  std::vector<size_t>::iterator a2 = a1 + l1;
  // The last right run may be shorter than chunk_size.
  const size_t l2 = (a.size() > i + chunk_size * 2) ? chunk_size : a.size() - i - chunk_size;
  tmp.clear();
  size_t res = 0;
  for (size_t i1 = 0, i2 = 0; i1 != l1 || i2 != l2; ) {
    if (i2 == l2 || (i1 != l1 && a1[i1] < a2[i2])) {
      tmp.push_back(a1[i1++]);
    } else {
      // a2[i2] is <= every element still left in the (sorted) left run, so it
      // forms a crossing with each of them. This adds 0 once the left run is exhausted.
      res += l1 - i1;
      tmp.push_back(a2[i2++]);
    }
  }
  for (size_t k = 0; k < tmp.size(); ++k)
    a1[k] = tmp[k];
  return res;
}

// Returns the number of pairs (i, j) with i < j and a[i] >= a[j], computed by
// a bottom-up merge sort. As a side effect `a` ends up sorted ascending.
size_t CountCross(std::vector<size_t>& a) {
  size_t res = 0;
  std::vector<size_t> tmp(a.size());
  for (size_t chunk_size = 1; chunk_size < a.size(); chunk_size *= 2)
    for (size_t i = 0; i + chunk_size < a.size(); i += chunk_size * 2)
      res += Merge(a, i, chunk_size, tmp);
  return res;
}

}  // namespace
// Reads one integer per line from the file named by argv[1] and prints the
// number of crossings (pairs i < j with a[i] >= a[j]) in that sequence.
// Exits with status 1 and a message on stderr when no file name is given or
// the file cannot be opened. Previously, a missing argument was undefined
// behaviour and an unopenable file silently printed 0.
int main(int argc, char **argv) {
  if (argc < 2) {
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "crossing") << " FILE" << std::endl;
    return 1;
  }
  std::ifstream file(argv[1]);
  if (!file) {
    std::cerr << "cannot open " << argv[1] << std::endl;
    return 1;
  }
  std::vector<size_t> a;
  std::string line;
  while (std::getline(file, line)) {
    // atoi stops at the first non-digit, so a trailing '\r' is ignored.
    a.push_back(std::atoi(line.c_str()));
  }
  std::cout << CountCross(a) << std::endl;
  return 0;
}
import std.stdio, std.file, std.algorithm, std.array, std.conv;
void main(string[] args) {
    import std.exception : enforce;

    // The input file name is the single command-line argument. Fail with a
    // usage message instead of a RangeError when it is missing.
    enforce(args.length >= 2, "usage: crossing FILE");

    // Read one unsigned integer per line into `a`.
    size_t[] a;
    foreach (l; File(args[1]).byLine())
        a ~= l.removeCRLF.to!size_t();

    stdout.writeln(countIntersections(a));
}
/// Returns the number of pairs (i, j) with i < j and a[i] >= a[j].
/// Runs a bottom-up merge sort: neighbouring sorted runs of width `width` are
/// merged into runs of width 2 * width, and the pairs found by each merge are
/// summed. `a` is left sorted ascending.
size_t countIntersections(size_t[] a) {
    auto scratch = new size_t[a.length];
    size_t total;
    for (size_t width = 1; width < a.length; width *= 2) {
        // Only windows whose right half is non-empty need merging.
        for (size_t lo = 0; lo + width < a.length; lo += 2 * width) {
            immutable hi = min(a.length, lo + 2 * width);
            total += merge(a[lo .. hi], width, scratch);
        }
    }
    return total;
}
/// `a` holds two adjacent sorted runs, a[0 .. chunkSize] and a[chunkSize .. $].
/// Merges them in place through `tmp` (which needs at least a.length slots)
/// and returns the number of (left, right) pairs with left >= right.
size_t merge(size_t[] a, size_t chunkSize, size_t[] tmp) {
    auto lhs = a[0 .. chunkSize];
    auto rhs = a[chunkSize .. $];
    size_t p, q, found;
    foreach (k; 0 .. a.length) {
        immutable fromRight = p == lhs.length
            || (q != rhs.length && rhs[q] <= lhs[p]);
        if (fromRight) {
            // rhs[q] is <= every unconsumed element of lhs, so each of them
            // forms an intersection with it (adds 0 once lhs is exhausted).
            found += lhs.length - p;
            tmp[k] = rhs[q++];
        } else {
            tmp[k] = lhs[p++];
        }
    }
    a[] = tmp[0 .. a.length];
    return found;
}
// Returns `cs` with one trailing line terminator removed: "\r\n", "\n" or "\r".
// Any other input is returned unchanged. The result is a slice of `cs`, not a copy.
// The CRLF check now accepts length 2, so "\r\n" by itself becomes "".
// Previously it became "\r".
@property C[] removeCRLF(C)(C[] cs) {
    if (cs.length >= 2 && cs[$ - 2] == '\r' && cs[$ - 1] == '\n')
        return cs[0 .. $ - 2];
    if (cs.length >= 1 && (cs[$ - 1] == '\n' || cs[$ - 1] == '\r'))
        return cs[0 .. $ - 1];
    return cs;
}
{-# LANGUAGE ViewPatterns, BangPatterns #-}
module Main (
main
) where
import qualified Data.Text as T
import qualified Data.Text.IO as TO
import qualified Data.Text.Read as TR
import qualified Data.Vector.Unboxed as VU
import Prelude hiding (length)
import Control.Monad (forM_)
import Control.Monad.ST as ST
import Data.Either.Unwrap (fromRight)
import Data.Functor ((<$>))
import Data.STRef as STRef
import Data.Vector.Unboxed.Mutable
(length, unsafeSlice, unsafeWrite, unsafeRead, new, STVector)
-- | Read one number per line from standard input and print how many
-- inverted pairs the sequence contains.
main :: IO ()
main = do
  contents <- TO.getContents
  print (countInversions (map parseInt (T.lines contents)))
-- | Parse the leading unsigned decimal number of a line; any text after the
-- digits (e.g. a trailing @\\r@) is ignored.
--
-- Calls 'error', naming the offending line, when the line does not start
-- with a digit. This replaces the partial 'fromRight', whose failure message
-- did not say which input was bad.
parseInt :: Integral a => T.Text -> a
parseInt t = case TR.decimal t of
  Right (n, _) -> n
  Left err -> error ("parseInt: " ++ err ++ ": " ++ show t)
-- | Count the pairs (i, j) with i < j and @xs !! i >= xs !! j@.
--
-- Runs a bottom-up merge sort over an unboxed mutable copy of the input and
-- sums the pairs discovered by each merge.
countInversions :: [Int] -> Int
countInversions input = runST $ do
  total <- newSTRef (0 :: Int)
  vec <- VU.unsafeThaw (VU.fromList input)
  let n = length vec
  scratch <- new n
  -- Run widths 1, 2, 4, ... below n. At each width, merge every pair of
  -- neighbouring runs whose right run is non-empty.
  forM_ (takeWhile (< n) (iterate (* 2) 1)) $ \width ->
    forM_ [0, 2 * width .. n - width - 1] $ \start -> do
      let middle = start + width
          stop = min n (start + 2 * width)
      found <-
        mergeRuns
          (unsafeSlice start width vec)
          (unsafeSlice middle (stop - middle) vec)
          scratch 0 0 0 0
      modifySTRef' total (+ found)
  readSTRef total
  where
    -- mergeRuns xs ys out p q k acc merges xs[p ..] and ys[q ..] into out[k ..].
    -- Whenever an element of ys is emitted, it adds (length xs - p) to acc.
    -- Once both runs are exhausted, the merged contents of out are copied back
    -- over xs then ys, and acc is returned.
    mergeRuns :: STVector s Int -> STVector s Int -> STVector s Int -> Int -> Int -> Int -> Int -> ST s Int
    mergeRuns xs ys out = go
      where
        nx = length xs
        ny = length ys
        go !p !q !k !acc
          | p == nx && q == ny = do
              forM_ [0 .. nx - 1] $ \j -> unsafeRead out j >>= unsafeWrite xs j
              forM_ [0 .. ny - 1] $ \j -> unsafeRead out (nx + j) >>= unsafeWrite ys j
              return acc
          | p == nx = emitRight 0
          | q == ny = emitLeft
          | otherwise = do
              x <- unsafeRead xs p
              y <- unsafeRead ys q
              if x < y then emitLeft else emitRight (nx - p)
          where
            emitLeft = do
              unsafeRead xs p >>= unsafeWrite out k
              go (p + 1) q (k + 1) acc
            emitRight extra = do
              unsafeRead ys q >>= unsafeWrite out k
              go p (q + 1) (k + 1) (acc + extra)
module Main (
main
) where
import qualified Data.Text as T
import qualified Data.Text.IO as TO
import qualified Data.Text.Read as TR
import Data.Either.Unwrap (fromRight)
import Data.Functor ((<$>))
-- | Print the number of inverted pairs in the numbers read from stdin, one
-- number per line.
main :: IO ()
main = TO.getContents >>= print . countInversions . map parseInt . T.lines
-- | Parse the leading unsigned decimal number of a line; any text after the
-- digits (e.g. a trailing @\\r@) is ignored.
--
-- Calls 'error', naming the offending line, when the line does not start
-- with a digit. This replaces the partial 'fromRight', whose failure message
-- did not say which input was bad.
parseInt :: Integral a => T.Text -> a
parseInt t = case TR.decimal t of
  Right (n, _) -> n
  Left err -> error ("parseInt: " ++ err ++ ": " ++ show t)
-- | Count the pairs (i, j) with i < j and @xs !! i >= xs !! j@, using a
-- top-down merge sort on plain lists.
countInversions :: [Int] -> Int
countInversions xs' = fst $ countInversions' xs' (length xs')
  where
    -- countInversions' ys n (n must equal length ys) returns the number of
    -- such pairs inside ys together with ys sorted ascending.
    countInversions' [] _ = (0, [])
    countInversions' [a] _ = (0, [a])
    countInversions' xs len = (leftCount + rightCount + mergeCount, sorted)
      where
        mid = len `div` 2
        (left, right) = splitAt mid xs
        (leftCount, sortedLeft) = countInversions' left mid
        (rightCount, sortedRight) = countInversions' right (len - mid)
        (mergeCount, sorted) = merge sortedLeft sortedRight 0 [] mid
    -- merge a0 a1 count acc l merges the sorted lists a0 and a1, where l is
    -- the number of elements left in a0. When the head of a1 is emitted first,
    -- it is <= each of those l elements, so l pairs are added to count.
    -- count and l are forced at every step. Otherwise each merge builds a chain
    -- of (+)/(-) thunks as long as its input, because the tuple result keeps
    -- them lazy.
    merge [] a1 count acc _ = (count, reverse acc ++ a1)
    merge a0 [] count acc _ = (count, reverse acc ++ a0)
    merge a0@(x0 : a0') a1@(x1 : a1') count acc l
      | x0 < x1 =
          let l' = l - 1
           in l' `seq` merge a0' a1 count (x0 : acc) l'
      | otherwise =
          let count' = count + l
           in count' `seq` merge a0 a1' count' (x1 : acc) l
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment