package smt
import (
"bytes"
"errors"
"hash"
)
// ErrBadProof is returned when an invalid Merkle proof is supplied.
var ErrBadProof = errors.New("bad proof")

// DeepSparseMerkleSubTree is a deep Sparse Merkle subtree for working on only a few leaves.
//
// It embeds *SparseMerkleTree, so the embedded tree's methods (Get, Update,
// Prove, ...) are available directly. Branches are populated with AddBranch.
type DeepSparseMerkleSubTree struct {
	*SparseMerkleTree
}
// NewDeepSparseMerkleSubTree creates a deep Sparse Merkle subtree over the given
// node and value stores, rooted at root. The stores are expected to start empty;
// the branches of interest are then added with AddBranch.
func NewDeepSparseMerkleSubTree(nodes, values MapStore, hasher hash.Hash, root []byte) *DeepSparseMerkleSubTree {
	tree := ImportSparseMerkleTree(nodes, values, hasher, root)
	return &DeepSparseMerkleSubTree{SparseMerkleTree: tree}
}
// AddBranch verifies proof for (key, value) against the subtree's root and, if
// it is valid, stores every node recomputed along the branch so the key can
// later be read or updated locally.
// These branches are generated by smt.ProveForRoot.
// If the proof is invalid, ErrBadProof is returned; store errors are returned as-is.
//
// If the leaf may be updated (e.g. during a state transition fraud proof),
// an updatable proof should be used. See SparseMerkleTree.ProveUpdatable.
func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []byte, value []byte) error {
	ok, updates := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher)
	if !ok {
		return ErrBadProof
	}

	// A membership proof carries a real value: record it under the key's path.
	if !bytes.Equal(value, defaultValue) {
		if err := dsmst.values.Set(dsmst.th.path(key), value); err != nil {
			return err
		}
	}

	// Persist each (hash, preimage) pair recomputed while verifying the branch.
	for _, pair := range updates {
		if err := dsmst.nodes.Set(pair[0], pair[1]); err != nil {
			return err
		}
	}

	// Updatable proofs also carry the preimage of the leaf's sibling, which
	// hashes to SideNodes[0] (checked by the proof's sanity check).
	if proof.SiblingData != nil && len(proof.SideNodes) > 0 {
		return dsmst.nodes.Set(proof.SideNodes[0], proof.SiblingData)
	}
	return nil
}
// GetDescend gets the value of a key by walking the tree from the root through
// the node store, instead of reading the value store directly.
// Use if a key was _not_ previously added with AddBranch, otherwise use Get.
// It returns an error if a node on the way down cannot be loaded.
func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) {
	placeholder := smt.th.placeholder()
	node := smt.Root()
	if bytes.Equal(node, placeholder) {
		// Empty tree: every key holds the default value.
		return defaultValue, nil
	}

	path := smt.th.path(key)
	for level := 0; level < smt.depth(); level++ {
		data, err := smt.nodes.Get(node)
		if err != nil {
			return nil, err
		}

		if smt.th.isLeaf(data) {
			leafPath, _ := smt.th.parseLeaf(data)
			if !bytes.Equal(path, leafPath) {
				// Another key's leaf occupies this position, so ours is absent.
				return defaultValue, nil
			}
			value, err := smt.values.Get(path)
			if err != nil {
				return nil, err
			}
			return value, nil
		}

		leftChild, rightChild := smt.th.parseNode(data)
		node = leftChild
		if getBitAtFromMSB(path, level) == right {
			node = rightChild
		}

		if bytes.Equal(node, placeholder) {
			// Descended into an empty subtree; the key is unset.
			return defaultValue, nil
		}
	}

	// Only reached when the branch is a full depth() levels deep, which should
	// be very unlikely with a collision-resistant hash function.
	value, err := smt.values.Get(path)
	if err != nil {
		return nil, err
	}
	return value, nil
}
// HasDescend reports whether key holds a non-default value, determined by
// descending the tree (see GetDescend).
// Use if a key was _not_ previously added with AddBranch, otherwise use Has.
// It returns false and the error if the key cannot be reached by descending.
func (smt *SparseMerkleTree) HasDescend(key []byte) (bool, error) {
	val, err := smt.GetDescend(key)
	if err != nil {
		return false, err
	}
	isDefault := bytes.Equal(defaultValue, val)
	return !isDefault, nil
}
package delete
import (
"bytes"
"crypto/sha256"
"github.com/celestiaorg/smt"
)
// Fuzz splits data on '*' and treats every part except the last as alternating
// key/value pairs to insert. It then deletes the last part as a key. Inputs
// with fewer than three parts are rejected (-1); a delete error yields 0.
// A successful delete must produce a non-empty root, otherwise it panics.
func Fuzz(data []byte) int {
	if len(data) == 0 {
		return -1
	}

	parts := bytes.Split(data, []byte("*"))
	if len(parts) < 3 {
		return -1
	}

	tree := smt.NewSparseMerkleTree(smt.NewSimpleMap(), smt.NewSimpleMap(), sha256.New())

	last := len(parts) - 1
	for i := 0; i < last; i += 2 {
		// Update errors are deliberately ignored; only Delete is under test.
		tree.Update(parts[i], parts[i+1])
	}

	root, err := tree.Delete(parts[last])
	if err != nil {
		return 0
	}
	if len(root) == 0 {
		panic("newRoot is nil yet err==nil")
	}
	return 1
}
package fuzz
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"math"
"github.com/celestiaorg/smt"
)
// Fuzz interprets input as a stream of operations against a fresh tree. Each
// operation byte selects Get/Update/Delete/Prove/Has. The key it uses is either
// freshly read from the stream or one already seen. Inputs shorter than 100
// bytes are skipped (0).
func Fuzz(input []byte) int {
	if len(input) < 100 {
		return 0
	}

	nodeStore, valueStore := smt.NewSimpleMap(), smt.NewSimpleMap()
	tree := smt.NewSparseMerkleTree(nodeStore, valueStore, sha256.New())
	r := bytes.NewReader(input)

	var seen [][]byte
	// nextKey reads a fresh key (and remembers it) when the selector byte is low,
	// otherwise reuses a previously seen key (nil if none yet).
	nextKey := func() []byte {
		if readByte(r) < math.MaxUint8/2 {
			fresh := make([]byte, readByte(r)/2)
			_, _ = r.Read(fresh) // a short read is acceptable input for fuzzing
			seen = append(seen, fresh)
			return fresh
		}
		if len(seen) == 0 {
			return nil
		}
		return seen[int(readByte(r))%len(seen)]
	}

	for i := 0; r.Len() != 0; i++ {
		b, err := r.ReadByte()
		if err != nil {
			continue
		}

		switch op(int(b) % int(Noop)) {
		case Get:
			tree.Get(nextKey())
		case Update:
			value := make([]byte, 32)
			binary.BigEndian.PutUint64(value, uint64(i))
			tree.Update(nextKey(), value)
		case Delete:
			tree.Delete(nextKey())
		case Prove:
			tree.Prove(nextKey())
		case Has:
			tree.Has(nextKey())
		}
	}
	return 1
}
// op is a fuzzer operation selected by one input byte.
type op int

// The order of these constants defines the byte-to-operation mapping used by
// Fuzz; Noop is the count of real operations and is used as the modulus.
const (
	Get op = iota
	Update
	Delete
	Prove
	Has
	Noop
)

// readByte returns the next byte from r, or 0 once r is exhausted.
func readByte(r *bytes.Reader) byte {
	if b, err := r.ReadByte(); err == nil {
		return b
	}
	return 0
}
package smt
import (
"fmt"
)
// MapStore is a key-value store. SparseMerkleTree uses one for node
// preimages and one for leaf values.
type MapStore interface {
	// Get gets the value for a key. SimpleMap returns *InvalidKeyError for a missing key.
	Get(key []byte) ([]byte, error)
	// Set updates the value for a key.
	Set(key []byte, value []byte) error
	// Delete deletes a key.
	Delete(key []byte) error
}

// InvalidKeyError is returned when a key that does not exist is being accessed.
type InvalidKeyError struct {
	Key []byte
}

// Error formats the missing key in hex.
func (e *InvalidKeyError) Error() string {
	return fmt.Sprintf("invalid key: %x", e.Key)
}

// SimpleMap is a simple in-memory MapStore backed by a Go map. It is not safe
// for concurrent use.
type SimpleMap struct {
	m map[string][]byte
}

// NewSimpleMap creates a new empty SimpleMap.
func NewSimpleMap() *SimpleMap {
	store := &SimpleMap{}
	store.m = make(map[string][]byte)
	return store
}

// Get returns the value stored for key, or *InvalidKeyError if there is none.
func (sm *SimpleMap) Get(key []byte) ([]byte, error) {
	value, ok := sm.m[string(key)]
	if !ok {
		return nil, &InvalidKeyError{Key: key}
	}
	return value, nil
}

// Set stores value under key, replacing any previous value. It never fails.
func (sm *SimpleMap) Set(key []byte, value []byte) error {
	sm.m[string(key)] = value
	return nil
}

// Delete removes key, or returns *InvalidKeyError if it is not present.
func (sm *SimpleMap) Delete(key []byte) error {
	k := string(key)
	if _, found := sm.m[k]; !found {
		return &InvalidKeyError{Key: key}
	}
	delete(sm.m, k)
	return nil
}
package smt
import (
"bytes"
"hash"
"math"
)
// SparseMerkleProof is a Merkle proof for an element in a SparseMerkleTree.
type SparseMerkleProof struct {
	// SideNodes is an array of the sibling nodes leading up to the leaf of the proof.
	// They are ordered leaf-first: SideNodes[0] is the sibling closest to the leaf,
	// and the last entry is the sibling just below the root. Each entry must be
	// exactly hasher.Size() bytes.
	SideNodes [][]byte

	// NonMembershipLeafData is the data of the unrelated leaf at the position
	// of the key being proven, in the case of a non-membership proof. For
	// membership proofs, is nil.
	NonMembershipLeafData []byte

	// SiblingData is the data of the sibling node to the leaf being proven,
	// required for updatable proofs. For unupdatable proofs, is nil.
	// When set, it must hash to SideNodes[0].
	SiblingData []byte
}
// sanityCheck rejects structurally malformed proofs up front, so that a
// malicious proof cannot make the verifier panic (e.g. by slicing out of
// range) or burn CPU. It does not verify the proof against a root.
func (proof *SparseMerkleProof) sanityCheck(th *treeHasher) bool {
	hashSize := th.hasher.Size()

	// No more side nodes than there are levels in the tree.
	if len(proof.SideNodes) > th.pathSize()*8 {
		return false
	}

	// Non-membership leaf data must be exactly prefix || path || value hash.
	if proof.NonMembershipLeafData != nil &&
		len(proof.NonMembershipLeafData) != len(leafPrefix)+th.pathSize()+hashSize {
		return false
	}

	// Every side node must be a full-size hash.
	for _, sideNode := range proof.SideNodes {
		if len(sideNode) != hashSize {
			return false
		}
	}

	// If sibling data is supplied, it must hash to the first side node.
	if proof.SiblingData != nil && len(proof.SideNodes) > 0 {
		return bytes.Equal(proof.SideNodes[0], th.digest(proof.SiblingData))
	}
	return true
}
// SparseCompactMerkleProof is a compact Merkle proof for an element in a SparseMerkleTree.
// It is produced by CompactProof and expanded again by DecompactProof.
type SparseCompactMerkleProof struct {
	// SideNodes is an array of the sibling nodes leading up to the leaf of the proof.
	// Placeholder side nodes are omitted; BitMask records where they were.
	SideNodes [][]byte

	// NonMembershipLeafData is the data of the unrelated leaf at the position
	// of the key being proven, in the case of a non-membership proof. For
	// membership proofs, is nil.
	NonMembershipLeafData []byte

	// BitMask, in the case of a compact proof, is a bit mask of the sidenodes
	// of the proof where an on-bit indicates that the sidenode at the bit's
	// index is a placeholder. This is only set if the proof is compact.
	// Bits are numbered from the most significant bit of BitMask[0].
	BitMask []byte

	// NumSideNodes, in the case of a compact proof, indicates the number of
	// sidenodes in the proof when decompacted. This is only set if the proof is compact.
	NumSideNodes int

	// SiblingData is the data of the sibling node to the leaf being proven,
	// required for updatable proofs. For unupdatable proofs, is nil.
	SiblingData []byte
}
// sanityCheck performs a basic check on the fields specific to the compact
// proof only, so that a malicious compact proof cannot make DecompactProof
// index out of range.
//
// When the proof is de-compacted and verified, the sanity check for the
// de-compacted proof should be executed.
func (proof *SparseCompactMerkleProof) sanityCheck(th *treeHasher) bool {
	// NumSideNodes must be within the range of possible tree depths.
	if proof.NumSideNodes < 0 || proof.NumSideNodes > th.pathSize()*8 {
		return false
	}

	// The bit mask must be exactly as long as NumSideNodes requires.
	if len(proof.BitMask) != int(math.Ceil(float64(proof.NumSideNodes)/float64(8))) {
		return false
	}

	// Count placeholders only among the first NumSideNodes bits. Padding bits in
	// the final byte are not side-node positions; counting them (as a whole-mask
	// popcount would) lets a crafted proof pass with too few side nodes and then
	// panic in DecompactProof.
	placeholders := 0
	for i := 0; i < proof.NumSideNodes; i++ {
		if getBitAtFromMSB(proof.BitMask, i) == 1 {
			placeholders++
		}
	}

	// Exactly one supplied side node per non-placeholder position (zero when
	// NumSideNodes is zero).
	return len(proof.SideNodes) == proof.NumSideNodes-placeholders
}
// VerifyProof reports whether proof shows that key maps to value under root.
// Passing the default (empty) value checks a non-membership proof.
func VerifyProof(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) bool {
	valid, _ := verifyProofWithUpdates(proof, root, key, value, hasher)
	return valid
}
// verifyProofWithUpdates recomputes the root implied by proof for (key, value)
// and reports whether it equals root. It also returns every (hash, preimage)
// pair computed along the way, leaf first, so callers can store the branch.
// The default value selects non-membership verification.
func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) (bool, [][][]byte) {
	th := newTreeHasher(hasher)
	path := th.path(key)

	if !proof.sanityCheck(th) {
		return false, nil
	}

	var updates [][][]byte
	record := func(h, data []byte) {
		updates = append(updates, [][]byte{h, data})
	}

	// Work out the hash of the node at the key's position.
	var currentHash, currentData []byte
	switch {
	case !bytes.Equal(value, defaultValue):
		// Membership proof: the leaf for this key and value.
		currentHash, currentData = th.digestLeaf(path, th.digest(value))
		record(currentHash, currentData)
	case proof.NonMembershipLeafData == nil:
		// Non-membership proof ending at an empty subtree.
		currentHash = th.placeholder()
	default:
		// Non-membership proof ending at some other key's leaf.
		otherPath, otherValueHash := th.parseLeaf(proof.NonMembershipLeafData)
		if bytes.Equal(otherPath, path) {
			// The "unrelated" leaf is actually ours; the proof fails.
			return false, nil
		}
		currentHash, currentData = th.digestLeaf(otherPath, otherValueHash)
		record(currentHash, currentData)
	}

	// Hash upward with each side node; SideNodes[0] is the deepest sibling.
	numSideNodes := len(proof.SideNodes)
	for i, sideNode := range proof.SideNodes {
		sibling := make([]byte, th.pathSize())
		copy(sibling, sideNode)
		if getBitAtFromMSB(path, numSideNodes-1-i) == right {
			currentHash, currentData = th.digestNode(sibling, currentHash)
		} else {
			currentHash, currentData = th.digestNode(currentHash, sibling)
		}
		record(currentHash, currentData)
	}

	return bytes.Equal(currentHash, root), updates
}
// VerifyCompactProof expands a compact proof and verifies it like VerifyProof.
// A proof that cannot be decompacted is reported as invalid.
func VerifyCompactProof(proof SparseCompactMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) bool {
	if expanded, err := DecompactProof(proof, hasher); err == nil {
		return VerifyProof(expanded, root, key, value, hasher)
	}
	return false
}
// CompactProof shrinks a proof by dropping placeholder side nodes and recording
// their positions in a bit mask. It returns ErrBadProof if the proof fails its
// sanity check.
func CompactProof(proof SparseMerkleProof, hasher hash.Hash) (SparseCompactMerkleProof, error) {
	th := newTreeHasher(hasher)
	if !proof.sanityCheck(th) {
		return SparseCompactMerkleProof{}, ErrBadProof
	}

	numSideNodes := len(proof.SideNodes)
	bitMask := emptyBytes(int(math.Ceil(float64(numSideNodes) / float64(8))))

	var kept [][]byte
	for i, sideNode := range proof.SideNodes {
		node := make([]byte, th.hasher.Size())
		copy(node, sideNode)
		if bytes.Equal(node, th.placeholder()) {
			// Placeholders are implied by the mask instead of being stored.
			setBitAtFromMSB(bitMask, i)
			continue
		}
		kept = append(kept, node)
	}

	return SparseCompactMerkleProof{
		SideNodes:             kept,
		NonMembershipLeafData: proof.NonMembershipLeafData,
		BitMask:               bitMask,
		NumSideNodes:          numSideNodes,
		SiblingData:           proof.SiblingData,
	}, nil
}
// DecompactProof expands a compact proof back into a SparseMerkleProof that can
// be passed to VerifyProof. It re-inserts a placeholder for every side node
// flagged in BitMask.
//
// It returns ErrBadProof if the proof fails its compact sanity check. It also
// returns ErrBadProof if the supplied side nodes do not match, one-for-one, the
// unflagged positions among the first NumSideNodes bits of BitMask. The bounds
// are checked here too, so a malformed proof can never cause an out-of-range
// index.
func DecompactProof(proof SparseCompactMerkleProof, hasher hash.Hash) (SparseMerkleProof, error) {
	th := newTreeHasher(hasher)
	if !proof.sanityCheck(th) {
		return SparseMerkleProof{}, ErrBadProof
	}

	decompactedSideNodes := make([][]byte, proof.NumSideNodes)
	position := 0
	for i := 0; i < proof.NumSideNodes; i++ {
		if getBitAtFromMSB(proof.BitMask, i) == 1 {
			decompactedSideNodes[i] = th.placeholder()
			continue
		}
		if position >= len(proof.SideNodes) {
			// More non-placeholder positions than supplied side nodes.
			return SparseMerkleProof{}, ErrBadProof
		}
		decompactedSideNodes[i] = proof.SideNodes[position]
		position++
	}
	if position != len(proof.SideNodes) {
		// Supplied side nodes that no position consumed.
		return SparseMerkleProof{}, ErrBadProof
	}

	return SparseMerkleProof{
		SideNodes:             decompactedSideNodes,
		NonMembershipLeafData: proof.NonMembershipLeafData,
		SiblingData:           proof.SiblingData,
	}, nil
}
// Package smt implements a Sparse Merkle tree.
package smt
import (
"bytes"
"errors"
"hash"
)
// right is the path-bit value (see getBitAtFromMSB) that selects the right child.
const right = 1

var (
	// defaultValue is the value reported for unset keys; updating a key to it
	// deletes the key.
	defaultValue = []byte{}

	// errKeyAlreadyEmpty signals, internally, a delete of a key that is not present.
	errKeyAlreadyEmpty = errors.New("key already empty")
)
// SparseMerkleTree is a Sparse Merkle tree.
type SparseMerkleTree struct {
	// th hashes keys into paths and builds/parses leaf and node preimages.
	th treeHasher
	// nodes maps node hashes to their preimages; values maps key paths to leaf values.
	nodes, values MapStore
	// root is the current root hash; it is the placeholder for an empty tree.
	root []byte
}
// NewSparseMerkleTree creates an empty Sparse Merkle tree over the given stores.
// The options are applied in order before the root is set to the placeholder.
func NewSparseMerkleTree(nodes, values MapStore, hasher hash.Hash, options ...Option) *SparseMerkleTree {
	tree := &SparseMerkleTree{
		th:     *newTreeHasher(hasher),
		nodes:  nodes,
		values: values,
	}
	for _, apply := range options {
		apply(tree)
	}
	tree.SetRoot(tree.th.placeholder())
	return tree
}
// ImportSparseMerkleTree wraps already-populated stores as a tree whose current
// root is root. Nothing is read from the stores at this point.
func ImportSparseMerkleTree(nodes, values MapStore, hasher hash.Hash, root []byte) *SparseMerkleTree {
	return &SparseMerkleTree{
		th:     *newTreeHasher(hasher),
		nodes:  nodes,
		values: values,
		root:   root,
	}
}
// Root returns the tree's current root hash.
func (smt *SparseMerkleTree) Root() []byte {
	return smt.root
}

// SetRoot replaces the tree's current root hash without consulting the stores.
func (smt *SparseMerkleTree) SetRoot(root []byte) {
	smt.root = root
}

// depth returns the maximum height of the tree: one level per bit of a path.
func (smt *SparseMerkleTree) depth() int {
	bitsPerByte := 8
	return bitsPerByte * smt.th.pathSize()
}
// Get returns the value stored for key, reading the value store directly.
// Unset keys (and every key of an empty tree) yield the default value. Store
// errors other than *InvalidKeyError are returned to the caller.
func (smt *SparseMerkleTree) Get(key []byte) ([]byte, error) {
	if bytes.Equal(smt.Root(), smt.th.placeholder()) {
		return defaultValue, nil
	}

	value, err := smt.values.Get(smt.th.path(key))
	if err == nil {
		return value, nil
	}

	var invalidKeyError *InvalidKeyError
	if errors.As(err, &invalidKeyError) {
		// A missing key simply means the default value.
		return defaultValue, nil
	}
	return nil, err
}
// Has reports whether key holds a non-default value. Any error from Get is
// returned alongside the result.
func (smt *SparseMerkleTree) Has(key []byte) (bool, error) {
	val, err := smt.Get(key)
	present := !bytes.Equal(defaultValue, val)
	return present, err
}
// Update sets key to value against the current root. On success it makes the
// resulting root current and returns it; on failure the root is left unchanged.
func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) {
	root, err := smt.UpdateForRoot(key, value, smt.Root())
	if err != nil {
		return nil, err
	}
	smt.SetRoot(root)
	return root, nil
}

// Delete removes key from the tree by updating it to the default value, and
// returns the new root.
func (smt *SparseMerkleTree) Delete(key []byte) ([]byte, error) {
	return smt.Update(key, defaultValue)
}
// UpdateForRoot sets key to value in the tree rooted at root and returns the
// new root; the tree's current root is not changed. Setting the default value
// deletes the key, and deleting a key that is already empty returns root
// unchanged.
//
// If rebuilding the branch fails, the error is returned before the key's entry
// in the value store is touched, so values are never deleted for a failed update.
func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte) ([]byte, error) {
	path := smt.th.path(key)
	sideNodes, pathNodes, oldLeafData, _, err := smt.sideNodesForRoot(path, root, false)
	if err != nil {
		return nil, err
	}

	if !bytes.Equal(value, defaultValue) {
		// Insert or update operation.
		return smt.updateWithSideNodes(path, value, sideNodes, pathNodes, oldLeafData)
	}

	// Delete operation.
	newRoot, err := smt.deleteWithSideNodes(path, sideNodes, pathNodes, oldLeafData)
	if errors.Is(err, errKeyAlreadyEmpty) {
		// This key is already empty; return the old root.
		return root, nil
	}
	if err != nil {
		return nil, err
	}
	if err := smt.values.Delete(path); err != nil {
		return nil, err
	}
	return newRoot, nil
}
// DeleteForRoot removes key from the tree rooted at root and returns the
// resulting root, leaving the tree's current root untouched. It is
// UpdateForRoot with the default value.
func (smt *SparseMerkleTree) DeleteForRoot(key, root []byte) ([]byte, error) {
	emptied := defaultValue
	return smt.UpdateForRoot(key, emptied, root)
}
// deleteWithSideNodes removes the leaf for path from the branch described by
// sideNodes and pathNodes. Both are ordered leaf-first, as returned by
// sideNodesForRoot. It returns the new root hash.
//
// It returns errKeyAlreadyEmpty, without touching any store, in two cases: the
// node at path's position is a placeholder, or it is a leaf for a different
// path. Otherwise it deletes every node in pathNodes from smt.nodes and rebuilds
// the branch bottom-up:
//   - if the deleted leaf's sibling is itself a leaf, that leaf is "bubbled up"
//     past placeholder siblings until it meets the first non-placeholder one;
//   - if the sibling is an internal node, a placeholder takes the deleted
//     leaf's place.
//
// Each rebuilt internal node is written to smt.nodes. The value stored under
// path in smt.values is not touched here.
func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte, pathNodes [][]byte, oldLeafData []byte) ([]byte, error) {
	if bytes.Equal(pathNodes[0], smt.th.placeholder()) {
		// This key is already empty as it is a placeholder; return an error.
		return nil, errKeyAlreadyEmpty
	}
	actualPath, _ := smt.th.parseLeaf(oldLeafData)
	if !bytes.Equal(path, actualPath) {
		// This key is already empty as a different key was found in its place; return an error.
		return nil, errKeyAlreadyEmpty
	}
	// The deleted leaf and every node above it on the branch are now orphaned.
	for _, node := range pathNodes {
		if err := smt.nodes.Delete(node); err != nil {
			return nil, err
		}
	}

	// currentData holds the hash (not the preimage) of the child to combine at
	// the next level up; it stays nil until the first side node is examined.
	var currentHash, currentData []byte
	nonPlaceholderReached := false
	for i, sideNode := range sideNodes {
		if currentData == nil {
			// First iteration: inspect the deleted leaf's direct sibling.
			sideNodeValue, err := smt.nodes.Get(sideNode)
			if err != nil {
				return nil, err
			}

			if smt.th.isLeaf(sideNodeValue) {
				// This is the leaf sibling that needs to be bubbled up the tree.
				currentHash = sideNode
				currentData = sideNode
				continue
			} else {
				// This is the node sibling that needs to be left in its place.
				currentData = smt.th.placeholder()
				nonPlaceholderReached = true
			}
		}

		if !nonPlaceholderReached && bytes.Equal(sideNode, smt.th.placeholder()) {
			// We found another placeholder sibling node, keep going up the
			// tree until we find the first sibling that is not a placeholder.
			continue
		} else if !nonPlaceholderReached {
			// We found the first sibling node that is not a placeholder, it is
			// time to insert our leaf sibling node here.
			nonPlaceholderReached = true
		}

		// sideNodes[i] sits at path bit len(sideNodes)-1-i (leaf-first ordering).
		if getBitAtFromMSB(path, len(sideNodes)-1-i) == right {
			currentHash, currentData = smt.th.digestNode(sideNode, currentData)
		} else {
			currentHash, currentData = smt.th.digestNode(currentData, sideNode)
		}
		if err := smt.nodes.Set(currentHash, currentData); err != nil {
			return nil, err
		}
		currentData = currentHash
	}

	if currentHash == nil {
		// The tree is empty; return placeholder value as root.
		currentHash = smt.th.placeholder()
	}
	return currentHash, nil
}
// updateWithSideNodes inserts or overwrites the leaf for path with value, using
// the branch collected by sideNodesForRoot, and returns the new root hash.
//
// sideNodes and pathNodes are ordered leaf-first (sideNodesForRoot reverses
// them). pathNodes[0] is the node currently at path's position, and the last
// element of pathNodes is the root the branch was read from. oldLeafData is
// pathNodes[0]'s preimage when that node is a leaf. New nodes are written to
// smt.nodes, replaced nodes are deleted, and value is stored under path in
// smt.values.
func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, sideNodes [][]byte, pathNodes [][]byte, oldLeafData []byte) ([]byte, error) {
	valueHash := smt.th.digest(value)
	currentHash, currentData := smt.th.digestLeaf(path, valueHash)
	if err := smt.nodes.Set(currentHash, currentData); err != nil {
		return nil, err
	}
	currentData = currentHash

	// If the leaf node that sibling nodes lead to has a different actual path
	// than the leaf node being updated, we need to create an intermediate node
	// with this leaf node and the new leaf node as children.
	//
	// First, get the number of bits that the paths of the two leaf nodes share
	// in common as a prefix.
	var commonPrefixCount int
	var oldValueHash []byte
	if bytes.Equal(pathNodes[0], smt.th.placeholder()) {
		commonPrefixCount = smt.depth()
	} else {
		var actualPath []byte
		actualPath, oldValueHash = smt.th.parseLeaf(oldLeafData)
		commonPrefixCount = countCommonPrefix(path, actualPath)
	}

	if commonPrefixCount != smt.depth() {
		// A different key's leaf sits here: join it and the new leaf under a
		// fresh internal node at the first bit where their paths diverge.
		if getBitAtFromMSB(path, commonPrefixCount) == right {
			currentHash, currentData = smt.th.digestNode(pathNodes[0], currentData)
		} else {
			currentHash, currentData = smt.th.digestNode(currentData, pathNodes[0])
		}

		err := smt.nodes.Set(currentHash, currentData)
		if err != nil {
			return nil, err
		}

		currentData = currentHash
	} else if oldValueHash != nil {
		// Short-circuit if the same value is being set: nothing changes, so
		// return the root this branch was read from. smt.root would be wrong
		// when UpdateForRoot is called with a root other than the current one.
		if bytes.Equal(oldValueHash, valueHash) {
			return pathNodes[len(pathNodes)-1], nil
		}
		// If an old leaf exists, remove it
		if err := smt.nodes.Delete(pathNodes[0]); err != nil {
			return nil, err
		}
		if err := smt.values.Delete(path); err != nil {
			return nil, err
		}
	}

	// All remaining path nodes are orphaned
	for i := 1; i < len(pathNodes); i++ {
		if err := smt.nodes.Delete(pathNodes[i]); err != nil {
			return nil, err
		}
	}

	// The offset from the bottom of the tree to the start of the side nodes.
	// Note: i-offsetOfSideNodes is the index into sideNodes[]
	offsetOfSideNodes := smt.depth() - len(sideNodes)

	for i := 0; i < smt.depth(); i++ {
		var sideNode []byte

		if i-offsetOfSideNodes < 0 || sideNodes[i-offsetOfSideNodes] == nil {
			if commonPrefixCount != smt.depth() && commonPrefixCount > smt.depth()-1-i {
				// If there are no sidenodes at this height, but the number of
				// bits that the paths of the two leaf nodes share in common is
				// greater than this depth, then we need to build up the tree
				// to this depth with placeholder values at siblings.
				sideNode = smt.th.placeholder()
			} else {
				continue
			}
		} else {
			sideNode = sideNodes[i-offsetOfSideNodes]
		}

		if getBitAtFromMSB(path, smt.depth()-1-i) == right {
			currentHash, currentData = smt.th.digestNode(sideNode, currentData)
		} else {
			currentHash, currentData = smt.th.digestNode(currentData, sideNode)
		}
		err := smt.nodes.Set(currentHash, currentData)
		if err != nil {
			return nil, err
		}
		currentData = currentHash
	}
	if err := smt.values.Set(path, value); err != nil {
		return nil, err
	}

	return currentHash, nil
}
// sideNodesForRoot walks from root towards path, one level per bit of path
// (most significant bit first), and collects the branch it passes through.
//
// It returns, in order:
//   - sideNodes: the sibling hash at each level visited, leaf-first
//     (collected root-first, then reversed before returning);
//   - pathNodes: root followed by the hash of each node visited, also
//     reversed, so pathNodes[0] is the deepest node reached and the last
//     element is root;
//   - the stored data of the deepest node reached (normally a leaf's
//     preimage), or nil when root is a placeholder or the walk ends at a
//     placeholder;
//   - when getSiblingData is true, the stored data of the last sibling
//     visited, i.e. the deepest node's sibling; otherwise nil;
//   - any error from the node store.
//
// If root is a placeholder or a leaf, no side nodes are returned.
//
// NOTE(review): with getSiblingData, if that last sibling is a placeholder,
// nodes.Get is presumably expected to fail (placeholders do not appear to be
// stored). Confirm this cannot occur for updatable proofs.
func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte, getSiblingData bool) ([][]byte, [][]byte, []byte, []byte, error) {
	// Side nodes for the path. Nodes are inserted in reverse order, then the
	// slice is reversed at the end.
	sideNodes := make([][]byte, 0, smt.depth())
	pathNodes := make([][]byte, 0, smt.depth()+1)
	pathNodes = append(pathNodes, root)

	if bytes.Equal(root, smt.th.placeholder()) {
		// If the root is a placeholder, there are no sidenodes to return.
		// Let the "actual path" be the input path.
		return sideNodes, pathNodes, nil, nil, nil
	}

	currentData, err := smt.nodes.Get(root)
	if err != nil {
		return nil, nil, nil, nil, err
	} else if smt.th.isLeaf(currentData) {
		// If the root is a leaf, there are also no sidenodes to return.
		return sideNodes, pathNodes, currentData, nil, nil
	}

	var nodeHash []byte
	var sideNode []byte
	var siblingData []byte
	for i := 0; i < smt.depth(); i++ {
		leftNode, rightNode := smt.th.parseNode(currentData)

		// Get sidenode depending on whether the path bit is on or off.
		if getBitAtFromMSB(path, i) == right {
			sideNode = leftNode
			nodeHash = rightNode
		} else {
			sideNode = rightNode
			nodeHash = leftNode
		}
		sideNodes = append(sideNodes, sideNode)
		pathNodes = append(pathNodes, nodeHash)

		if bytes.Equal(nodeHash, smt.th.placeholder()) {
			// If the node is a placeholder, we've reached the end.
			currentData = nil
			break
		}

		currentData, err = smt.nodes.Get(nodeHash)
		if err != nil {
			return nil, nil, nil, nil, err
		} else if smt.th.isLeaf(currentData) {
			// If the node is a leaf, we've reached the end.
			break
		}
	}

	if getSiblingData {
		// sideNode still holds the last sibling appended in the loop above.
		siblingData, err = smt.nodes.Get(sideNode)
		if err != nil {
			return nil, nil, nil, nil, err
		}
	}

	return reverseByteSlices(sideNodes), reverseByteSlices(pathNodes), currentData, siblingData, nil
}
// Prove generates a Merkle proof for key against the tree's current root.
//
// The proof suits read-only use. It must not be used if the leaf may be updated
// (e.g. in a state transition fraud proof); see ProveUpdatable for that.
func (smt *SparseMerkleTree) Prove(key []byte) (SparseMerkleProof, error) {
	return smt.ProveForRoot(key, smt.Root())
}

// ProveForRoot generates a Merkle proof for key against an arbitrary root,
// which is mainly useful for proving against subtrees.
//
// The proof suits read-only use. It must not be used if the leaf may be updated
// (e.g. in a state transition fraud proof); see ProveUpdatableForRoot for that.
func (smt *SparseMerkleTree) ProveForRoot(key []byte, root []byte) (SparseMerkleProof, error) {
	const updatable = false
	return smt.doProveForRoot(key, root, updatable)
}

// ProveUpdatable generates an updatable Merkle proof (one that also carries the
// leaf's sibling data) for key against the tree's current root.
func (smt *SparseMerkleTree) ProveUpdatable(key []byte) (SparseMerkleProof, error) {
	return smt.ProveUpdatableForRoot(key, smt.Root())
}

// ProveUpdatableForRoot generates an updatable Merkle proof for key against an
// arbitrary root, which is mainly useful for proving against subtrees.
func (smt *SparseMerkleTree) ProveUpdatableForRoot(key []byte, root []byte) (SparseMerkleProof, error) {
	const updatable = true
	return smt.doProveForRoot(key, root, updatable)
}
// doProveForRoot builds a proof for key against root. With isUpdatable set, it
// also fetches the data of the leaf's sibling. When a different key's leaf
// occupies the key's position, that leaf's data is included as the
// non-membership witness.
func (smt *SparseMerkleTree) doProveForRoot(key []byte, root []byte, isUpdatable bool) (SparseMerkleProof, error) {
	path := smt.th.path(key)
	sideNodes, pathNodes, leafData, siblingData, err := smt.sideNodesForRoot(path, root, isUpdatable)
	if err != nil {
		return SparseMerkleProof{}, err
	}

	var nonEmptySideNodes [][]byte
	for _, sideNode := range sideNodes {
		if sideNode == nil {
			continue
		}
		nonEmptySideNodes = append(nonEmptySideNodes, sideNode)
	}

	// A placeholder at the key's position needs no extra witness. A leaf for a
	// different path is the witness for a non-membership proof.
	var nonMembershipLeafData []byte
	if !bytes.Equal(pathNodes[0], smt.th.placeholder()) {
		if actualPath, _ := smt.th.parseLeaf(leafData); !bytes.Equal(actualPath, path) {
			nonMembershipLeafData = leafData
		}
	}

	return SparseMerkleProof{
		SideNodes:             nonEmptySideNodes,
		NonMembershipLeafData: nonMembershipLeafData,
		SiblingData:           siblingData,
	}, nil
}
// ProveCompact generates a compacted Merkle proof for key against the tree's
// current root.
func (smt *SparseMerkleTree) ProveCompact(key []byte) (SparseCompactMerkleProof, error) {
	return smt.ProveCompactForRoot(key, smt.Root())
}

// ProveCompactForRoot generates a read-only proof for key against root and
// compacts it with CompactProof.
func (smt *SparseMerkleTree) ProveCompactForRoot(key []byte, root []byte) (SparseCompactMerkleProof, error) {
	proof, err := smt.ProveForRoot(key, root)
	if err != nil {
		return SparseCompactMerkleProof{}, err
	}
	return CompactProof(proof, smt.th.hasher)
}
package smt
import (
"bytes"
"hash"
)
// leafPrefix and nodePrefix domain-separate leaf and internal-node preimages,
// so a leaf can never be confused with an internal node of the same bytes.
var leafPrefix = []byte{0}
var nodePrefix = []byte{1}

// treeHasher wraps a hash.Hash with the tree's hashing and (de)serialisation
// rules. It reuses a single hasher, so it is not safe for concurrent use.
type treeHasher struct {
	hasher    hash.Hash
	zeroValue []byte
}

// newTreeHasher returns a treeHasher whose placeholder is hasher.Size() zero bytes.
func newTreeHasher(hasher hash.Hash) *treeHasher {
	th := treeHasher{hasher: hasher}
	th.zeroValue = make([]byte, th.pathSize())
	return &th
}

// digest returns the hash of data, leaving the hasher reset for the next call.
func (th *treeHasher) digest(data []byte) []byte {
	th.hasher.Write(data) // hash.Hash.Write never returns an error
	sum := th.hasher.Sum(nil)
	th.hasher.Reset()
	return sum
}

// path maps a key to its position in the tree (the hash of the key).
func (th *treeHasher) path(key []byte) []byte {
	return th.digest(key)
}

// digestPrefixed hashes prefix || a || b and returns (hash, preimage).
func (th *treeHasher) digestPrefixed(prefix, a, b []byte) ([]byte, []byte) {
	value := make([]byte, 0, len(prefix)+len(a)+len(b))
	value = append(value, prefix...)
	value = append(value, a...)
	value = append(value, b...)
	return th.digest(value), value
}

// digestLeaf returns the hash and preimage (leafPrefix || path || leafData) of a leaf.
func (th *treeHasher) digestLeaf(path []byte, leafData []byte) ([]byte, []byte) {
	return th.digestPrefixed(leafPrefix, path, leafData)
}

// parseLeaf splits a leaf preimage into (path, value hash). data must be at
// least len(leafPrefix)+pathSize() bytes long, otherwise it panics.
func (th *treeHasher) parseLeaf(data []byte) ([]byte, []byte) {
	return data[len(leafPrefix) : th.pathSize()+len(leafPrefix)], data[len(leafPrefix)+th.pathSize():]
}

// isLeaf reports whether data is a leaf preimage. Empty or short data is not a
// leaf (rather than causing an out-of-range panic).
func (th *treeHasher) isLeaf(data []byte) bool {
	return bytes.HasPrefix(data, leafPrefix)
}

// digestNode returns the hash and preimage (nodePrefix || left || right) of an internal node.
func (th *treeHasher) digestNode(leftData []byte, rightData []byte) ([]byte, []byte) {
	return th.digestPrefixed(nodePrefix, leftData, rightData)
}

// parseNode splits an internal-node preimage into its (left, right) child
// hashes. data must be at least len(nodePrefix)+pathSize() bytes long.
func (th *treeHasher) parseNode(data []byte) ([]byte, []byte) {
	return data[len(nodePrefix) : th.pathSize()+len(nodePrefix)], data[len(nodePrefix)+th.pathSize():]
}

// pathSize is the length in bytes of a path, equal to the hash output size.
func (th *treeHasher) pathSize() int {
	return th.hasher.Size()
}

// placeholder is the all-zero hash that stands for an empty subtree.
func (th *treeHasher) placeholder() []byte {
	return th.zeroValue
}
package smt
// getBitAtFromMSB returns the bit (1 or 0) at position in data, numbering bits
// from the most significant bit of data[0].
func getBitAtFromMSB(data []byte, position int) int {
	mask := byte(1) << (7 - uint(position)%8)
	if data[position/8]&mask == 0 {
		return 0
	}
	return 1
}

// setBitAtFromMSB turns on the bit at position in data, numbering bits from
// the most significant bit of data[0].
func setBitAtFromMSB(data []byte, position int) {
	data[position/8] |= byte(1) << (7 - uint(position)%8)
}

// countSetBits returns the number of 1 bits in data.
func countSetBits(data []byte) int {
	total := 0
	for _, b := range data {
		// Clear the lowest set bit until none remain.
		for ; b != 0; b &= b - 1 {
			total++
		}
	}
	return total
}

// countCommonPrefix returns how many leading bits (MSB first) data1 and data2
// share, examining at most len(data1)*8 bits.
func countCommonPrefix(data1 []byte, data2 []byte) int {
	n := 0
	for n < len(data1)*8 && getBitAtFromMSB(data1, n) == getBitAtFromMSB(data2, n) {
		n++
	}
	return n
}

// emptyBytes returns a zero-filled byte slice of the given length.
func emptyBytes(length int) []byte {
	return make([]byte, length)
}

// reverseByteSlices reverses slices in place and returns it.
func reverseByteSlices(slices [][]byte) [][]byte {
	for i, j := 0, len(slices)-1; i < j; i, j = i+1, j-1 {
		slices[i], slices[j] = slices[j], slices[i]
	}
	return slices
}