-
-
Save avary/9f3c1e843430aa2682e2f5616496e484 to your computer and use it in GitHub Desktop.
Concurrency-safe web scraper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"sync"

	"golang.org/x/net/html"
)
// wg tracks every in-flight getValidURLs goroutine so main can block
// until the entire crawl has finished.
// (The previous `mut` mutex was removed: it was declared but never
// referenced anywhere in the file.)
var wg = &sync.WaitGroup{}
// safeSlice is a slice of strings guarded by an embedded RWMutex so it can
// be read and appended to concurrently from many goroutines. The zero value
// is ready to use (nil data slice, unlocked mutex).
type safeSlice struct {
	sync.RWMutex
	data []string
}
func (s *safeSlice) Contains(str string) bool { | |
s.RLock() | |
defer s.RUnlock() | |
return slices.Contains(s.data, str) | |
} | |
func (s *safeSlice) Add(str string) { | |
s.Lock() | |
defer s.Unlock() | |
s.data = append(s.data, str) | |
} | |
func (s *safeSlice) Length() int { | |
s.RLock() | |
defer s.RUnlock() | |
return len(s.data) | |
} | |
// seen records every URL the crawler has attempted (whether or not the fetch
// succeeded), so the same page is never processed twice.
var seen = &safeSlice{}

// results records only the URLs that responded with 200 OK.
var results = &safeSlice{}
func main() { | |
const URL = "https://scrape-me.dreamsofcode.io" | |
wg.Add(1) | |
go getValidURLs(URL) | |
wg.Wait() | |
fmt.Println("Total URLs found:", results.Length()) | |
fmt.Println("Total seen URLs:", seen.Length()) | |
} | |
func (s *safeSlice) AddIfNotSeen(str string) bool { | |
s.Lock() | |
defer s.Unlock() | |
if slices.Contains(s.data, str) { | |
return false | |
} | |
s.data = append(s.data, str) | |
return true | |
} | |
func getValidURLs(url string) { | |
defer wg.Done() | |
// if we know the url has been seen, go away | |
if seen.Contains(url) { | |
return | |
} | |
statusCode, node := fetchURL(url) | |
// if the status is not ok, we've seen it. Return | |
if statusCode != http.StatusOK { | |
seen.Add(url) | |
return | |
} | |
if seen.AddIfNotSeen(url) { | |
results.Add(url) | |
} | |
hrefs := getHrefs(url, node) | |
for _, href := range hrefs { | |
wg.Add(1) | |
// recursively call getValidURLs | |
go getValidURLs(href) | |
} | |
} | |
func getHrefs(baseHost string, n *html.Node) []string { | |
result := []string{} | |
if n.Type == html.ElementNode && n.Data == "a" { | |
for _, a := range n.Attr { | |
if a.Key == "href" && isBaseHost(baseHost, a.Val) { | |
parsedURL, err := url.Parse(baseHost) | |
if err != nil { | |
panic(err) | |
} | |
result = append(result, addPath(parsedURL.Scheme+"://"+parsedURL.Host, a.Val)) | |
break | |
} | |
} | |
} | |
for c := n.FirstChild; c != nil; c = c.NextSibling { | |
result = append(result, getHrefs(baseHost, c)...) | |
} | |
return result | |
} | |
// isBaseHost reports whether href stays on the site being crawled: either a
// root-relative path ("/...") or an absolute URL beginning with baseHost.
func isBaseHost(baseHost string, href string) bool {
	if strings.HasPrefix(href, "/") {
		return true
	}
	return strings.HasPrefix(href, baseHost)
}
// addPath joins href onto baseHost, producing an absolute URL. An href that
// is already absolute on this host is returned unchanged.
func addPath(baseHost string, href string) string {
	// Already a full URL on this host — nothing to join.
	if strings.HasPrefix(href, baseHost) {
		return href
	}
	// Normalize both sides so the join yields exactly one separating slash.
	host := strings.TrimSuffix(baseHost, "/")
	path := strings.TrimPrefix(href, "/")
	return host + "/" + path
}
func fetchURL(url string) (statusCode int, node *html.Node) { | |
fmt.Println("Checking URL:", url) | |
response, err := http.Get(url) | |
if err != nil { | |
panic(err) | |
} | |
defer response.Body.Close() | |
if response.StatusCode != http.StatusOK { | |
return response.StatusCode, nil | |
} | |
bytes, err := io.ReadAll(response.Body) | |
if err != nil { | |
panic(err) | |
} | |
rawHtml := string(bytes) | |
_node, err := html.Parse(strings.NewReader(rawHtml)) | |
if err != nil { | |
panic(err) | |
} | |
return response.StatusCode, _node | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment