diff options
Diffstat (limited to 'lib/httpcache/httpcache.go')
-rw-r--r-- | lib/httpcache/httpcache.go | 211 |
1 files changed, 211 insertions, 0 deletions
diff --git a/lib/httpcache/httpcache.go b/lib/httpcache/httpcache.go new file mode 100644 index 0000000..b2cc7fe --- /dev/null +++ b/lib/httpcache/httpcache.go @@ -0,0 +1,211 @@ +package httpcache + +import ( + "bufio" + hash "crypto/md5" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "sort" + "strings" +) + +var ( + UserAgent string + ModifyResponse func(url string, entry CacheEntry, resp *http.Response) *http.Response + CheckRedirect func(req *http.Request, via []*http.Request) error +) + +type CacheEntry string + +var memCache = map[string]CacheEntry{} + +type httpStatusError struct { + StatusCode int + Status string +} + +// Is implements the interface for [errors.Is]. +func (e *httpStatusError) Is(target error) bool { + switch target { + case os.ErrNotExist: + return e.StatusCode == http.StatusNotFound + default: + return false + } +} + +// Error implements [error]. +func (e *httpStatusError) Error() string { + return fmt.Sprintf("unexpected HTTP status: %v", e.Status) +} + +type transport struct{} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + // Return an error for things that are the fault of things + // not-on-this-box. Panic for things that are the fault of + // this box. + + // Initialize. + if err := os.Mkdir(".http-cache", 0777); err != nil && !os.IsExist(err) { + panic(err) + } + + // Calculate cache-key. + u := req.URL.String() + cacheKey := url.QueryEscape(u) + hdrKeys := make([]string, 0, len(req.Header)) + for k := range req.Header { + switch k { + case "User-Agent": + case "Referer": + default: + hdrKeys = append(hdrKeys, http.CanonicalHeaderKey(k)) + } + } + sort.Strings(hdrKeys) + for _, k := range hdrKeys { + cacheKey += "|" + url.QueryEscape(k) + ":" + url.QueryEscape(req.Header[k][0]) + } + if len(cacheKey) >= 255 { + prefix := cacheKey[:255-(hash.Size*2)] + csum := hash.Sum([]byte(cacheKey)) + suffix := hex.EncodeToString(csum[:]) + cacheKey = prefix + suffix + } + cacheFile := filepath.Join(".http-cache", cacheKey) + + // Check the mem cache. + if _, ok := memCache[cacheKey]; ok { + fmt.Printf("GET|CACHE|MEM %q...", u) + goto end + } + // Check the file cache. + if bs, err := os.ReadFile(cacheFile); err == nil { + str := string(bs) + if strings.HasPrefix(str, "HTTP/") || strings.HasPrefix(str, "CLIENT/") { + fmt.Printf("GET|CACHE|FILE %q...", u) + memCache[cacheKey] = CacheEntry(str) + goto end + } + } + + // Do the request for real. + fmt.Printf("GET|NET %q...", u) + if resp, err := http.DefaultTransport.RoundTrip(req); err == nil { + var buf strings.Builder + if err := resp.Write(&buf); err != nil { + panic(err) + } + memCache[cacheKey] = CacheEntry(buf.String()) + } else { + memCache[cacheKey] = CacheEntry("CLIENT/" + err.Error()) + } + + // Record the response to the file cache. + if err := os.WriteFile(cacheFile, []byte(memCache[cacheKey]), 0666); err != nil { + panic(err) + } + +end: + // Turn the cache entry into an http.Response (or error) + var ret_resp *http.Response + var ret_err error + entry := memCache[cacheKey] + switch { + case strings.HasPrefix(string(entry), "HTTP/"): + var err error + ret_resp, err = http.ReadResponse(bufio.NewReader(strings.NewReader(string(entry))), nil) + if err != nil { + panic(fmt.Errorf("invalid cache entry: %v", err)) + } + if ModifyResponse != nil { + ret_resp = ModifyResponse(u, entry, ret_resp) + } + case strings.HasPrefix(string(entry), "CLIENT/"): + ret_err = errors.New(string(entry)[len("CLIENT/"):]) + default: + panic("invalid cache entry: invalid prefix") + } + + // Return. + if ret_err != nil { + fmt.Printf(" err\n") + } else { + fmt.Printf(" http %v\n", ret_resp.StatusCode) + } + return ret_resp, ret_err +} + +func Get(u string, hdr map[string]string) (string, error) { + if UserAgent == "" { + panic("main() must set the user agent string") + } + req, err := http.NewRequest(http.MethodGet, u, nil) + if err != nil { + panic(fmt.Errorf("should not happen: http.NewRequest: %v", err)) + } + req.Header.Set("User-Agent", UserAgent) + for k, v := range hdr { + req.Header.Add(k, v) + } + client := &http.Client{ + Transport: &transport{}, + CheckRedirect: CheckRedirect, + } + resp, err := client.Do(req) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", &httpStatusError{StatusCode: resp.StatusCode, Status: resp.Status} + } + bs, err := io.ReadAll(resp.Body) + if err != nil { + panic(fmt.Errorf("should not happen: strings.Reader.Read: %v", err)) + } + return string(bs), nil +} + +func GetJSON(u string, hdr map[string]string, out any) error { + str, err := Get(u, hdr) + if err != nil { + return err + } + return json.Unmarshal([]byte(str), out) +} + +func GetPaginatedJSON[T any](uStr string, hdr map[string]string, out *[]T, pageFn func(i int) url.Values) error { + u, err := url.Parse(uStr) + if err != nil { + return err + } + query := u.Query() + + for i := 0; true; i++ { + pageParams := pageFn(i) + for k, v := range pageParams { + query[k] = v + } + + u.RawQuery = query.Encode() + var resp []T + if err := GetJSON(u.String(), hdr, &resp); err != nil { + return err + } + fmt.Printf(" -> %d records\n", len(resp)) + if len(resp) == 0 { + break + } + *out = append(*out, resp...) + } + + return nil +} |