|
@@ -0,0 +1,169 @@
|
|
|
+package security
|
|
|
+
|
|
|
+import (
|
|
|
+ "crypto/hmac"
|
|
|
+ "crypto/md5"
|
|
|
+ "crypto/sha256"
|
|
|
+ "encoding/base64"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "log"
|
|
|
+ "net/http"
|
|
|
+ "os"
|
|
|
+ "strconv"
|
|
|
+ "strings"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
+ "github.com/tal-tech/go-zero/core/codec"
|
|
|
+ "github.com/tal-tech/go-zero/core/fs"
|
|
|
+ "github.com/tal-tech/go-zero/rest/httpx"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ pubKey = `-----BEGIN PUBLIC KEY-----
|
|
|
+MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCyeDYV2ieOtNDi6tuNtAbmUjN9
|
|
|
+pTHluAU5yiKEz8826QohcxqUKP3hybZBcm60p+rUxMAJFBJ8Dt+UJ6sEMzrf1rOF
|
|
|
+YOImVvORkXjpFU7sCJkhnLMs/kxtRzcZJG6ADUlG4GDCNcZpY/qELEvwgm2kCcHi
|
|
|
+tGC2mO8opFFFHTR0aQIDAQAB
|
|
|
+-----END PUBLIC KEY-----`
|
|
|
+ priKey = `-----BEGIN RSA PRIVATE KEY-----
|
|
|
+MIICXQIBAAKBgQCyeDYV2ieOtNDi6tuNtAbmUjN9pTHluAU5yiKEz8826QohcxqU
|
|
|
+KP3hybZBcm60p+rUxMAJFBJ8Dt+UJ6sEMzrf1rOFYOImVvORkXjpFU7sCJkhnLMs
|
|
|
+/kxtRzcZJG6ADUlG4GDCNcZpY/qELEvwgm2kCcHitGC2mO8opFFFHTR0aQIDAQAB
|
|
|
+AoGAcENv+jT9VyZkk6karLuG75DbtPiaN5+XIfAF4Ld76FWVOs9V88cJVON20xpx
|
|
|
+ixBphqexCMToj8MnXuHJEN5M9H15XXx/9IuiMm3FOw0i6o0+4V8XwHr47siT6T+r
|
|
|
+HuZEyXER/2qrm0nxyC17TXtd/+TtpfQWSbivl6xcAEo9RRECQQDj6OR6AbMQAIDn
|
|
|
+v+AhP/y7duDZimWJIuMwhigA1T2qDbtOoAEcjv3DB1dAswJ7clcnkxI9a6/0RDF9
|
|
|
+0IEHUcX9AkEAyHdcegWiayEnbatxWcNWm1/5jFnCN+GTRRFrOhBCyFr2ZdjFV4T+
|
|
|
+acGtG6omXWaZJy1GZz6pybOGy93NwLB93QJARKMJ0/iZDbOpHqI5hKn5mhd2Je25
|
|
|
+IHDCTQXKHF4cAQ+7njUvwIMLx2V5kIGYuMa5mrB/KMI6rmyvHv3hLewhnQJBAMMb
|
|
|
+cPUOENMllINnzk2oEd3tXiscnSvYL4aUeoErnGP2LERZ40/YD+mMZ9g6FVboaX04
|
|
|
+0oHf+k5mnXZD7WJyJD0CQQDJ2HyFbNaUUHK+lcifCibfzKTgmnNh9ZpePFumgJzI
|
|
|
+EfFE5H+nzsbbry2XgJbWzRNvuFTOLWn4zM+aFyy9WvbO
|
|
|
+-----END RSA PRIVATE KEY-----`
|
|
|
+ body = "hello world!"
|
|
|
+)
|
|
|
+
|
|
|
+var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
|
|
+
|
|
|
+func TestContentSecurity(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ mode string
|
|
|
+ extraKey string
|
|
|
+ extraSecret string
|
|
|
+ extraTime string
|
|
|
+ err error
|
|
|
+ code int
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "encrypted",
|
|
|
+ mode: "1",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "unencrypted",
|
|
|
+ mode: "0",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "bad content type",
|
|
|
+ mode: "a",
|
|
|
+ err: ErrInvalidContentType,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "bad secret",
|
|
|
+ mode: "1",
|
|
|
+ extraSecret: "any",
|
|
|
+ err: ErrInvalidSecret,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "bad key",
|
|
|
+ mode: "1",
|
|
|
+ extraKey: "any",
|
|
|
+ err: ErrInvalidKey,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "bad time",
|
|
|
+ mode: "1",
|
|
|
+ extraTime: "any",
|
|
|
+ code: httpx.CodeSignatureInvalidHeader,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, test := range tests {
|
|
|
+ test := test
|
|
|
+ t.Run(test.name, func(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ r, err := http.NewRequest(http.MethodPost, "http://localhost:3333/a/b?c=first&d=second",
|
|
|
+ strings.NewReader(body))
|
|
|
+ assert.Nil(t, err)
|
|
|
+
|
|
|
+ timestamp := time.Now().Unix()
|
|
|
+ sha := sha256.New()
|
|
|
+ sha.Write([]byte(body))
|
|
|
+ bodySign := fmt.Sprintf("%x", sha.Sum(nil))
|
|
|
+ contentOfSign := strings.Join([]string{
|
|
|
+ strconv.FormatInt(timestamp, 10),
|
|
|
+ http.MethodPost,
|
|
|
+ r.URL.Path,
|
|
|
+ r.URL.RawQuery,
|
|
|
+ bodySign,
|
|
|
+ }, "\n")
|
|
|
+ sign := hs256(key, contentOfSign)
|
|
|
+ content := strings.Join([]string{
|
|
|
+ "version=v1",
|
|
|
+ "type=" + test.mode,
|
|
|
+ fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)) + test.extraKey,
|
|
|
+ "time=" + strconv.FormatInt(timestamp, 10) + test.extraTime,
|
|
|
+ }, "; ")
|
|
|
+
|
|
|
+ encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ output, err := encrypter.Encrypt([]byte(content))
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ encryptedContent := base64.StdEncoding.EncodeToString(output)
|
|
|
+ r.Header.Set("X-Content-Security", strings.Join([]string{
|
|
|
+ fmt.Sprintf("key=%s", fingerprint(pubKey)),
|
|
|
+ "secret=" + encryptedContent + test.extraSecret,
|
|
|
+ "signature=" + sign,
|
|
|
+ }, "; "))
|
|
|
+
|
|
|
+ file, err := fs.TempFilenameWithText(priKey)
|
|
|
+ assert.Nil(t, err)
|
|
|
+ defer os.Remove(file)
|
|
|
+
|
|
|
+ dec, err := codec.NewRsaDecrypter(file)
|
|
|
+ assert.Nil(t, err)
|
|
|
+
|
|
|
+ header, err := ParseContentSecurity(map[string]codec.RsaDecrypter{
|
|
|
+ fingerprint(pubKey): dec,
|
|
|
+ }, r)
|
|
|
+ assert.Equal(t, test.err, err)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ assert.Equal(t, test.code, VerifySignature(r, header, time.Minute))
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func fingerprint(key string) string {
|
|
|
+ h := md5.New()
|
|
|
+ io.WriteString(h, key)
|
|
|
+ return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
|
|
+}
|
|
|
+
|
|
|
+func hs256(key []byte, body string) string {
|
|
|
+ h := hmac.New(sha256.New, key)
|
|
|
+ io.WriteString(h, body)
|
|
|
+ return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
|
|
+}
|