diff --git a/contracts/implementation.sol b/contracts/implementation.sol index d3912f0..e023af4 100644 --- a/contracts/implementation.sol +++ b/contracts/implementation.sol @@ -24,6 +24,14 @@ contract PartialMerkleTreeImplementation { return tree.get(key); } + function safeGet(bytes key) public view returns (bytes) { + return tree.safeGet(key); + } + + function doesInclude(bytes key) public view returns (bool) { + return tree.doesInclude(key); + } + function getValue(bytes32 hash) public view returns (bytes) { return tree.values[hash]; } diff --git a/contracts/tree.sol b/contracts/tree.sol index ee9861d..ec927a5 100644 --- a/contracts/tree.sol +++ b/contracts/tree.sol @@ -88,6 +88,18 @@ library PartialMerkleTree { return getValue(tree, _findNode(tree, key)); } + function safeGet(Tree storage tree, bytes key) internal view returns (bytes value) { + bytes32 valueHash = _findNode(tree, key); + require(valueHash != bytes32(0)); + value = getValue(tree, valueHash); + require(valueHash == keccak256(value)); + } + + function doesInclude(Tree storage tree, bytes key) internal view returns (bool) { + bytes32 valueHash = _findNode(tree, key); + return (valueHash != bytes32(0)); + } + function getValue(Tree storage tree, bytes32 valueHash) internal view returns (bytes) { return tree.values[valueHash]; } diff --git a/test/PartialMerkleTree.Test.js b/test/PartialMerkleTree.Test.js index 3070fba..33ea17f 100644 --- a/test/PartialMerkleTree.Test.js +++ b/test/PartialMerkleTree.Test.js @@ -142,6 +142,30 @@ contract('PartialMerkleTree', async ([_, primary, nonPrimary]) => { assert.equal(web3.toUtf8(await tree.get('foo')), 'bar') }) }) + + describe('safeGet()', async () => { + it('should return stored value for the given key', async () => { + await tree.insert('foo', 'bar', { from: primary }) + assert.equal(web3.toUtf8(await tree.get('foo')), 'bar') + }) + it('should throw if the given key is not included', async () => { + await tree.insert('foo', 'bar', { from: primary }) + try { + await tree.get('fuz') + assert.fail('Did not reverted') + } catch (e) { + assert.ok('Reverted successfully') + } + }) + }) + + describe('doesInclude()', async () => { + it('should return boolean whether the tree includes the given key or not', async () => { + await tree.insert('foo', 'bar', { from: primary }) + assert.equal(await tree.doesInclude('foo'), true) + assert.equal(await tree.doesInclude('fuz'), false) + }) + }) }) context('We can reenact merkle tree transformation by submitting only referred siblings instead of submitting all nodes', async () => { @@ -166,19 +190,19 @@ contract('PartialMerkleTree', async ([_, primary, nonPrimary]) => { siblingsForKey1 = proof[1] }) - it('should start with same root hash by initialization', async()=> { + it('should start with same root hash by initialization', async () => { //initilaze with the first root hash await treeB.initialize(firstPhaseOfTreeA) assert.equal(await treeB.getRootHash(), firstPhaseOfTreeA) }) - it('should not change root after committing branch data', async ()=> { + it('should not change root after committing branch data', async () => { // commit branch data await treeB.commitBranch('key1', referredValueForKey1, branchMaskForKey1, siblingsForKey1) assert.equal(await treeB.getRootHash(), firstPhaseOfTreeA) }) - it('should be able to return proof data', async ()=> { + it('should be able to return proof data', async () => { // commit branch data await treeB.getProof('key1') })