concurrency safe web scraper
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"sync"

	"golang.org/x/net/html"
)

// wg tracks every in-flight getValidURLs goroutine.
// (The safeSlice type below carries its own RWMutex, so no separate
// package-level mutex is needed.)
var wg = &sync.WaitGroup{}

// safeSlice is a slice of strings guarded by an RWMutex so it can be
// read and appended to from multiple goroutines.
type safeSlice struct {
	sync.RWMutex
	data []string
}

func (s *safeSlice) Contains(str string) bool {
	s.RLock()
	defer s.RUnlock()
	return slices.Contains(s.data, str)
}

func (s *safeSlice) Add(str string) {
	s.Lock()
	defer s.Unlock()
	s.data = append(s.data, str)
}

func (s *safeSlice) Length() int {
	s.RLock()
	defer s.RUnlock()
	return len(s.data)
}

// seen holds every URL we have visited; results holds only those that
// returned 200 OK.
var seen = &safeSlice{}
var results = &safeSlice{}

func main() {
	const URL = "https://scrape-me.dreamsofcode.io"
	wg.Add(1)
	go getValidURLs(URL)
	wg.Wait()
	fmt.Println("Total URLs found:", results.Length())
	fmt.Println("Total seen URLs:", seen.Length())
}

// getValidURLs fetches url, records it, and recursively spawns a goroutine
// for every same-host link found on the page.
func getValidURLs(url string) {
	defer wg.Done()
	// if we already know this url, there is nothing to do
	if seen.Contains(url) {
		return
	}
	statusCode, node := fetchURL(url)
	// if the status is not OK, mark it as seen and return
	if statusCode != http.StatusOK {
		seen.Add(url)
		return
	}
	// check again: fetchURL takes time, and another goroutine
	// may have recorded this url in the meantime
	if !seen.Contains(url) {
		results.Add(url)
		seen.Add(url)
	}
	hrefs := getHrefs(url, node)
	for _, href := range hrefs {
		wg.Add(1)
		// recursively crawl each discovered link
		go getValidURLs(href)
	}
}

// getHrefs walks the parsed HTML tree and returns every same-host link,
// normalised to an absolute URL.
func getHrefs(baseHost string, n *html.Node) []string {
	result := []string{}
	if n.Type == html.ElementNode && n.Data == "a" {
		for _, a := range n.Attr {
			if a.Key == "href" && isBaseHost(baseHost, a.Val) {
				parsedURL, err := url.Parse(baseHost)
				if err != nil {
					panic(err)
				}
				result = append(result, addPath(parsedURL.Scheme+"://"+parsedURL.Host, a.Val))
				break
			}
		}
	}
	for c := n.FirstChild; c != nil; c = c.NextSibling {
		result = append(result, getHrefs(baseHost, c)...)
	}
	return result
}

// isBaseHost reports whether href is a relative link (starts with "/")
// or an absolute URL on the host we started from.
func isBaseHost(baseHost string, href string) bool {
	return strings.HasPrefix(href, "/") || strings.HasPrefix(href, baseHost)
}

// addPath joins a relative href onto baseHost, leaving absolute URLs alone.
func addPath(baseHost string, href string) string {
	// if href is already a full URL, return it as-is
	if strings.HasPrefix(href, baseHost) {
		return href
	}
	// remove the leading slash from href and any trailing slash from baseHost
	href = strings.TrimPrefix(href, "/")
	baseHost = strings.TrimSuffix(baseHost, "/")
	return baseHost + "/" + href
}

// fetchURL GETs url and, on a 200 response, parses the body as HTML.
func fetchURL(url string) (statusCode int, node *html.Node) {
	fmt.Println("Checking URL:", url)
	response, err := http.Get(url)
	if err != nil {
		panic(err)
	}
	defer response.Body.Close()
	if response.StatusCode != http.StatusOK {
		return response.StatusCode, nil
	}
	body, err := io.ReadAll(response.Body)
	if err != nil {
		panic(err)
	}
	doc, err := html.Parse(strings.NewReader(string(body)))
	if err != nil {
		panic(err)
	}
	return response.StatusCode, doc
}
On lines 74-77 you can see three separate operations (check seen, add to results, add to seen) that should really be performed atomically, as a single transaction, but are not. This is a race condition: another goroutine can interleave between the check and the adds. I leave fixing it as an exercise for the reader.
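
One possible fix, as a minimal sketch rather than anything from the gist itself: give safeSlice a combined check-and-add method so the membership test and the append happen under a single write lock. The AddIfAbsent name and the rewritten snippet below are my own suggestion, not the author's code.

// AddIfAbsent is a hypothetical helper (not in the gist). It appends str
// only if it is not already present, doing the check and the append under
// one write lock so no other goroutine can interleave between them.
// It reports whether str was newly added.
func (s *safeSlice) AddIfAbsent(str string) bool {
	s.Lock()
	defer s.Unlock()
	if slices.Contains(s.data, str) {
		return false
	}
	s.data = append(s.data, str)
	return true
}

With that helper, the three lines in getValidURLs could collapse to:

// only the goroutine that wins the AddIfAbsent call records the URL,
// so each URL lands in results exactly once
if seen.AddIfAbsent(url) {
	results.Add(url)
}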