package tkn20 import ( "bytes" "crypto/rand" "encoding/json" "fmt" "os" "testing" ) type TestCase struct { Policy string Success bool Attrs map[string]string `json:"attributes"` } func TestConcurrentDecryption(t *testing.T) { var tests []TestCase buf, _ := os.ReadFile("testdata/policies.json") err := json.Unmarshal(buf, &tests) if err != nil { t.Fatal(err) } msg := []byte("must have the precious") for i, test := range tests { t.Run(fmt.Sprintf("TestConcurrentDecryption:#%d", i), func(t *testing.T) { pk, msk, err := Setup(rand.Reader) if err != nil { t.Fatal(err) } policy := Policy{} err = policy.FromString(test.Policy) if err != nil { t.Fatal(err) } ct, err := pk.Encrypt(rand.Reader, policy, msg) if err != nil { t.Fatalf("encryption failed: %s", err) } attrs := Attributes{} attrs.FromMap(test.Attrs) sk, err := msk.KeyGen(rand.Reader, attrs) if err != nil { t.Fatalf("key generation failed: %s", err) } checkResults := func(ct []byte, sk AttributeKey, i int) { pt, err := sk.Decrypt(ct) if tests[i].Success { if err != nil { t.Errorf("decryption failed: %s", err) } if !bytes.Equal(pt, msg) { t.Errorf("expected %v, received %v", pt, msg) } } else { if err == nil { t.Errorf("decryption should have failed") } } } go checkResults(ct, sk, i) go checkResults(ct, sk, i) }) } } func TestEndToEndEncryption(t *testing.T) { var tests []TestCase buf, _ := os.ReadFile("testdata/policies.json") err := json.Unmarshal(buf, &tests) if err != nil { t.Fatal(err) } msg := []byte("must have the precious") for i, test := range tests { t.Run(fmt.Sprintf("TestEndToEndEncryption:#%d", i), func(t *testing.T) { pk, msk, err := Setup(rand.Reader) if err != nil { t.Fatal(err) } policy := Policy{} err = policy.FromString(test.Policy) if err != nil { t.Fatal(err) } ct, err := pk.Encrypt(rand.Reader, policy, msg) if err != nil { t.Fatalf("encryption failed: %s", err) } attrs := Attributes{} attrs.FromMap(test.Attrs) sk, err := msk.KeyGen(rand.Reader, attrs) if err != nil { t.Fatalf("key generation failed: %s", err) } npol := &Policy{} if err = npol.ExtractFromCiphertext(ct); err != nil { t.Fatalf("extraction failed: %s", err) } strpol := npol.String() npol2 := &Policy{} if err = npol2.FromString(strpol); err != nil { t.Fatalf("string %s didn't parse: %s", strpol, err) } sat := policy.Satisfaction(attrs) if sat != npol.Satisfaction(attrs) { t.Fatalf("extracted policy doesn't match original") } if sat != npol2.Satisfaction(attrs) { t.Fatalf("round tripped policy doesn't match original") } ctSat := attrs.CouldDecrypt(ct) pt, err := sk.Decrypt(ct) if test.Success { // test case should succeed if !sat { t.Fatalf("satisfaction failed") } if !ctSat { t.Fatalf("ciphertext satisfaction failed") } if err != nil { t.Fatalf("decryption failed: %s", err) } if !bytes.Equal(pt, msg) { t.Fatalf("expected %v, received %v", pt, msg) } } else { // test case should fail if sat { t.Fatal("satisfaction should have failed") } if ctSat { t.Fatal("ciphertext satisfaction should have failed") } if err == nil { t.Fatal("decryption should have failed") } } }) } } func TestMarshal(t *testing.T) { pk, msk, err := Setup(rand.Reader) if err != nil { t.Fatal(err) } data, err := pk.MarshalBinary() if err != nil { t.Fatal(err) } b := &PublicKey{} err = b.UnmarshalBinary(data) if err != nil { t.Fatal(err) } if !pk.Equal(b) { t.Fatal("PublicKey: failure to roundtrip") } data, err = msk.MarshalBinary() if err != nil { t.Fatal(err) } c := &SystemSecretKey{} err = c.UnmarshalBinary(data) if err != nil { t.Fatal(err) } if !msk.Equal(c) { t.Fatal("MasterSecretKey: failure to roundtrip") } attrs := Attributes{} attrs.FromMap(map[string]string{"occupation": "doctor", "country": "US", "age": "16"}) sk, err := msk.KeyGen(rand.Reader, attrs) if err != nil { t.Fatal(err) } data, err = sk.MarshalBinary() if err != nil { t.Fatal(err) } d := AttributeKey{} // don't use pointer to verify unmarshal works with both pointer and not err = d.UnmarshalBinary(data) if err != nil { t.Fatal(err) } if !sk.Equal(&d) { t.Fatal("SecretKey: failure to roundtrip") } } func TestPolicyMethods(t *testing.T) { policyStr := "(season: fall or season: winter) or (region: alaska and season: summer)" policy := Policy{} err := policy.FromString(policyStr) if err != nil { t.Fatal(err) } expected := map[string][]string{ "season": {"fall", "winter", "summer"}, "region": {"alaska"}, } received := policy.ExtractAttributeValuePairs() if len(expected) != len(received) { t.Fatal("diff lengths") } for k, vs := range expected { vs2, ok := received[k] if !ok { t.Fatalf("key %s not found in received map", k) } if len(vs) != len(vs2) { t.Fatalf("expected len: %d, received len: %d, for key %s", len(vs), len(vs2), k) } // compare each value for given key, order doesn't matter for _, v := range vs { flag := false for _, v2 := range vs2 { if v == v2 { flag = true break } } if !flag { t.Fatalf("expected and received values differ") } } } }