summaryrefslogtreecommitdiff
path: root/lib/httpcache/httpcache.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/httpcache/httpcache.go')
-rw-r--r--lib/httpcache/httpcache.go211
1 files changed, 211 insertions, 0 deletions
diff --git a/lib/httpcache/httpcache.go b/lib/httpcache/httpcache.go
new file mode 100644
index 0000000..b2cc7fe
--- /dev/null
+++ b/lib/httpcache/httpcache.go
@@ -0,0 +1,211 @@
+package httpcache
+
+import (
+ "bufio"
+ hash "crypto/md5"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+)
+
+var (
+ UserAgent string
+ ModifyResponse func(url string, entry CacheEntry, resp *http.Response) *http.Response
+ CheckRedirect func(req *http.Request, via []*http.Request) error
+)
+
+type CacheEntry string
+
+var memCache = map[string]CacheEntry{}
+
+type httpStatusError struct {
+ StatusCode int
+ Status string
+}
+
+// Is implements the interface for [errors.Is].
+func (e *httpStatusError) Is(target error) bool {
+ switch target {
+ case os.ErrNotExist:
+ return e.StatusCode == http.StatusNotFound
+ default:
+ return false
+ }
+}
+
+// Error implements [error].
+func (e *httpStatusError) Error() string {
+ return fmt.Sprintf("unexpected HTTP status: %v", e.Status)
+}
+
+type transport struct{}
+
+func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Return an error for things that are the fault of things
+ // not-on-this-box. Panic for things that are the fault of
+ // this box.
+
+ // Initialize.
+ if err := os.Mkdir(".http-cache", 0777); err != nil && !os.IsExist(err) {
+ panic(err)
+ }
+
+ // Calculate cache-key.
+ u := req.URL.String()
+ cacheKey := url.QueryEscape(u)
+ hdrKeys := make([]string, 0, len(req.Header))
+ for k := range req.Header {
+ switch k {
+ case "User-Agent":
+ case "Referer":
+ default:
+ hdrKeys = append(hdrKeys, http.CanonicalHeaderKey(k))
+ }
+ }
+ sort.Strings(hdrKeys)
+ for _, k := range hdrKeys {
+ cacheKey += "|" + url.QueryEscape(k) + ":" + url.QueryEscape(req.Header[k][0])
+ }
+ if len(cacheKey) >= 255 {
+ prefix := cacheKey[:255-(hash.Size*2)]
+ csum := hash.Sum([]byte(cacheKey))
+ suffix := hex.EncodeToString(csum[:])
+ cacheKey = prefix + suffix
+ }
+ cacheFile := filepath.Join(".http-cache", cacheKey)
+
+ // Check the mem cache.
+ if _, ok := memCache[cacheKey]; ok {
+ fmt.Printf("GET|CACHE|MEM %q...", u)
+ goto end
+ }
+ // Check the file cache.
+ if bs, err := os.ReadFile(cacheFile); err == nil {
+ str := string(bs)
+ if strings.HasPrefix(str, "HTTP/") || strings.HasPrefix(str, "CLIENT/") {
+ fmt.Printf("GET|CACHE|FILE %q...", u)
+ memCache[cacheKey] = CacheEntry(str)
+ goto end
+ }
+ }
+
+ // Do the request for real.
+ fmt.Printf("GET|NET %q...", u)
+ if resp, err := http.DefaultTransport.RoundTrip(req); err == nil {
+ var buf strings.Builder
+ if err := resp.Write(&buf); err != nil {
+ panic(err)
+ }
+ memCache[cacheKey] = CacheEntry(buf.String())
+ } else {
+ memCache[cacheKey] = CacheEntry("CLIENT/" + err.Error())
+ }
+
+ // Record the response to the file cache.
+ if err := os.WriteFile(cacheFile, []byte(memCache[cacheKey]), 0666); err != nil {
+ panic(err)
+ }
+
+end:
+ // Turn the cache entry into an http.Response (or error)
+ var ret_resp *http.Response
+ var ret_err error
+ entry := memCache[cacheKey]
+ switch {
+ case strings.HasPrefix(string(entry), "HTTP/"):
+ var err error
+ ret_resp, err = http.ReadResponse(bufio.NewReader(strings.NewReader(string(entry))), nil)
+ if err != nil {
+ panic(fmt.Errorf("invalid cache entry: %v", err))
+ }
+ if ModifyResponse != nil {
+ ret_resp = ModifyResponse(u, entry, ret_resp)
+ }
+ case strings.HasPrefix(string(entry), "CLIENT/"):
+ ret_err = errors.New(string(entry)[len("CLIENT/"):])
+ default:
+ panic("invalid cache entry: invalid prefix")
+ }
+
+ // Return.
+ if ret_err != nil {
+ fmt.Printf(" err\n")
+ } else {
+ fmt.Printf(" http %v\n", ret_resp.StatusCode)
+ }
+ return ret_resp, ret_err
+}
+
+func Get(u string, hdr map[string]string) (string, error) {
+ if UserAgent == "" {
+ panic("main() must set the user agent string")
+ }
+ req, err := http.NewRequest(http.MethodGet, u, nil)
+ if err != nil {
+ panic(fmt.Errorf("should not happen: http.NewRequest: %v", err))
+ }
+ req.Header.Set("User-Agent", UserAgent)
+ for k, v := range hdr {
+ req.Header.Add(k, v)
+ }
+ client := &http.Client{
+ Transport: &transport{},
+ CheckRedirect: CheckRedirect,
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return "", &httpStatusError{StatusCode: resp.StatusCode, Status: resp.Status}
+ }
+ bs, err := io.ReadAll(resp.Body)
+ if err != nil {
+ panic(fmt.Errorf("should not happen: strings.Reader.Read: %v", err))
+ }
+ return string(bs), nil
+}
+
+func GetJSON(u string, hdr map[string]string, out any) error {
+ str, err := Get(u, hdr)
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal([]byte(str), out)
+}
+
+func GetPaginatedJSON[T any](uStr string, hdr map[string]string, out *[]T, pageFn func(i int) url.Values) error {
+ u, err := url.Parse(uStr)
+ if err != nil {
+ return err
+ }
+ query := u.Query()
+
+ for i := 0; true; i++ {
+ pageParams := pageFn(i)
+ for k, v := range pageParams {
+ query[k] = v
+ }
+
+ u.RawQuery = query.Encode()
+ var resp []T
+ if err := GetJSON(u.String(), hdr, &resp); err != nil {
+ return err
+ }
+ fmt.Printf(" -> %d records\n", len(resp))
+ if len(resp) == 0 {
+ break
+ }
+ *out = append(*out, resp...)
+ }
+
+ return nil
+}