Skip to content

Instantly share code, notes, and snippets.

@tomowarkar
Last active December 5, 2020 11:01
Show Gist options
  • Save tomowarkar/bba333f9870577b2d2ecb7d2bdef916e to your computer and use it in GitHub Desktop.
Save tomowarkar/bba333f9870577b2d2ecb7d2bdef916e to your computer and use it in GitHub Desktop.
Golang プリム法
package main
import "fmt"
import "bufio"
import "bytes"
// Shared state for the local self-test harness.
var (
	// r is the input source that ni and nis scan from; main re-points it at
	// each test case's stdin before calling solve.
	r *bufio.Reader

	// testCases pairs a raw stdin payload (a vertex count followed by an
	// n×n weight matrix in which -1 marks a missing edge) with the expected
	// total weight returned by solve.
	testCases = []struct {
		stdin string
		want  int
	}{
		{stdin: `5
-1 2 3 1 -1
2 -1 -1 4 -1
3 -1 -1 1 1
1 4 1 -1 3
-1 -1 1 3 -1`, want: 5},
		{stdin: `7
-1 7 -1 5 -1 -1 -1
7 -1 8 9 7 -1 -1
-1 8 -1 -1 5 -1 -1
5 9 -1 -1 15 6 -1
-1 7 5 15 -1 8 9
-1 -1 -1 6 8 -1 11
-1 -1 -1 -1 9 11 -1`, want: 39},
	}
)
// main is a self-check harness: for every entry in testCases it points the
// shared reader r at that entry's stdin, runs solve, and prints one line
// reporting whether the result matched the expected value.
func main() {
	for idx := range testCases {
		tc := testCases[idx]
		r = bufio.NewReader(bytes.NewReader([]byte(tc.stdin)))

		got := solve()
		if got == tc.want {
			fmt.Printf("test_case: %d passed\n", idx)
			continue
		}
		fmt.Printf("test_case: %d want: %#v but get: %d\n", idx, tc.want, got)
	}
}
// prim holds the working state of Prim's algorithm over n vertices.
type prim struct {
	n   int // number of vertices
	sum int // total weight of the edges added to the tree so far
	// minCost[u] is the cheapest known edge weight connecting vertex u to the
	// vertices already in the tree (the 1<<30 sentinel while none is known).
	minCost []int
	// used[u] reports whether vertex u has already been added to the tree.
	used []bool
}

// initPrim returns fresh state for n vertices: sum is 0, no vertex is used,
// minCost[0] is 0 (so vertex 0 is selected first) and every other minCost
// entry is the sentinel 1<<30. A non-positive n yields empty state instead of
// panicking on minCost[0].
func initPrim(n int) *prim {
	const inf = 1 << 30
	if n < 0 {
		n = 0
	}
	mc := make([]int, n)
	for i := range mc {
		mc[i] = inf
	}
	if n > 0 {
		mc[0] = 0
	}
	return &prim{
		n:       n,
		sum:     0,
		minCost: mc,
		used:    make([]bool, n),
	}
}

// solvePrim returns the total edge weight of a minimum spanning tree of the
// n-vertex graph whose adjacency matrix is G (G[v][u] is the weight of the
// edge between v and u). isConnect reports whether a matrix entry denotes an
// existing edge.
//
// Each step scans all vertices, so the whole run is O(n²), which suits a
// dense adjacency matrix. If the graph is disconnected, a vertex that no
// edge from the current tree can reach starts a new tree at cost 0. The
// result is then the weight of a minimum spanning forest, rather than a sum
// polluted by the "unreached" sentinel.
func solvePrim(n int, G [][]int, isConnect func(i int) bool) int {
	p := initPrim(n)

	// reached[u] becomes true once any edge to u has been seen; vertex 0 is
	// the initial root and counts as reached.
	reached := make([]bool, p.n)
	if p.n > 0 {
		reached[0] = true
	}

	for {
		// Pick the not-yet-used vertex with the cheapest connection to the
		// tree; the lowest index wins ties.
		v := -1
		for u := 0; u < p.n; u++ {
			if !p.used[u] && (v == -1 || p.minCost[u] < p.minCost[v]) {
				v = u
			}
		}
		if v == -1 {
			break // every vertex has been added
		}

		p.used[v] = true
		if reached[v] {
			p.sum += p.minCost[v]
		} // else: v roots a new component, joined at cost 0

		// Relax: edges leaving v may offer cheaper connections to the
		// vertices that are still outside the tree.
		for u := 0; u < p.n; u++ {
			if p.used[u] {
				continue
			}
			if w := G[v][u]; isConnect(w) {
				reached[u] = true
				if w < p.minCost[u] {
					p.minCost[u] = w
				}
			}
		}
	}
	return p.sum
}
// solve reads a vertex count followed by that many rows of edge weights from
// r and returns solvePrim's result for the resulting adjacency matrix.
func solve() int {
	size := ni()

	adj := make([][]int, 0, size)
	for len(adj) < size {
		adj = append(adj, nis(size))
	}

	// In the input, a weight of -1 means vertices i and j share no edge.
	hasEdge := func(w int) bool { return w != -1 }
	return solvePrim(size, adj, hasEdge)
}
// ni scans the next whitespace-separated integer from r. The scan error is
// ignored, so a failed read yields 0.
func ni() (v int) {
	fmt.Fscan(r, &v)
	return v
}
// nis scans size whitespace-separated integers from r into a new slice.
// Each value is scanned directly into its own slot. A failed read therefore
// leaves that slot at 0, matching ni, instead of repeating the previous value
// as a shared temporary would.
func nis(size int) []int {
	res := make([]int, size)
	for i := range res {
		fmt.Fscan(r, &res[i])
	}
	return res
}
// min returns the smallest of its arguments, or 0 when it is called with none.
func min(a ...int) int {
	best := 0
	for i, x := range a {
		if i == 0 || x < best {
			best = x
		}
	}
	return best
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment