@@ -164,10 +164,12 @@ impl PyDictTracker {
164
164
mod tests {
165
165
use crate :: { ids:: PyIds , memory:: PyMemory , utils:: to_vm_error, vm_core:: PyVM } ;
166
166
use cairo_rs:: {
167
+ bigint,
167
168
hint_processor:: hint_processor_definition:: HintReference ,
168
169
serde:: deserialize_program:: { ApTracking , Member } ,
169
170
types:: relocatable:: Relocatable ,
170
171
types:: { instruction:: Register , relocatable:: MaybeRelocatable } ,
172
+ vm:: errors:: vm_errors:: VirtualMachineError ,
171
173
} ;
172
174
use num_bigint:: { BigInt , Sign } ;
173
175
use pyo3:: { types:: PyDict , PyCell } ;
@@ -439,6 +441,114 @@ assert dict_tracker.data[1] == 22
439
441
} ) ;
440
442
}
441
443
444
+ #[ test]
445
+ fn tracker_read_and_write_invalid_key ( ) {
446
+ Python :: with_gil ( |py| {
447
+ let vm = PyVM :: new (
448
+ BigInt :: new ( Sign :: Plus , vec ! [ 1 , 0 , 0 , 0 , 0 , 0 , 17 , 134217728 ] ) ,
449
+ false ,
450
+ ) ;
451
+ for _ in 0 ..2 {
452
+ vm. vm . borrow_mut ( ) . add_memory_segment ( ) ;
453
+ }
454
+
455
+ let dict_manager = PyDictManager :: default ( ) ;
456
+
457
+ let segment_manager = PySegmentManager :: new ( & vm, PyMemory :: new ( & vm) ) ;
458
+
459
+ //Create references
460
+ let mut references = HashMap :: new ( ) ;
461
+ references. insert (
462
+ String :: from ( "dict" ) ,
463
+ HintReference {
464
+ register : Some ( Register :: FP ) ,
465
+ offset1 : 0 ,
466
+ offset2 : 0 ,
467
+ inner_dereference : false ,
468
+ ap_tracking_data : None ,
469
+ immediate : None ,
470
+ dereference : true ,
471
+ cairo_type : Some ( String :: from ( "DictAccess*" ) ) ,
472
+ } ,
473
+ ) ;
474
+ // Create ids.a
475
+ references. insert ( String :: from ( "a" ) , HintReference :: new_simple ( 1 ) ) ;
476
+
477
+ //Insert ids.a into memory
478
+ vm. vm
479
+ . borrow_mut ( )
480
+ . insert_value (
481
+ & Relocatable :: from ( ( 1 , 1 ) ) ,
482
+ & MaybeRelocatable :: from ( ( 128 , 64 ) ) ,
483
+ )
484
+ . unwrap ( ) ;
485
+
486
+ let mut struct_types: HashMap < String , HashMap < String , Member > > = HashMap :: new ( ) ;
487
+ struct_types. insert ( String :: from ( "DictAccess" ) , HashMap :: new ( ) ) ;
488
+
489
+ let ids = PyIds :: new (
490
+ & vm,
491
+ & references,
492
+ & ApTracking :: default ( ) ,
493
+ & HashMap :: new ( ) ,
494
+ Rc :: new ( struct_types) ,
495
+ ) ;
496
+
497
+ let globals = PyDict :: new ( py) ;
498
+ globals
499
+ . set_item ( "dict_manager" , PyCell :: new ( py, dict_manager) . unwrap ( ) )
500
+ . unwrap ( ) ;
501
+ globals
502
+ . set_item ( "ids" , PyCell :: new ( py, ids) . unwrap ( ) )
503
+ . unwrap ( ) ;
504
+ globals
505
+ . set_item ( "segments" , PyCell :: new ( py, segment_manager) . unwrap ( ) )
506
+ . unwrap ( ) ;
507
+
508
+ let code = r#"
509
+ initial_dict = { 1: 2, 4: 8, 16: 32 }
510
+ ids.dict = dict_manager.new_dict(segments, initial_dict)
511
+ dict_tracker = dict_manager.get_tracker(ids.dict)
512
+ dict_tracker.data[3]
513
+ "# ;
514
+
515
+ let py_result = py. run ( code, Some ( globals) , None ) ;
516
+
517
+ assert_eq ! (
518
+ py_result. map_err( to_vm_error) ,
519
+ Err ( to_vm_error( to_py_error(
520
+ VirtualMachineError :: NoValueForKey ( bigint!( 3 ) )
521
+ ) ) ) ,
522
+ ) ;
523
+
524
+ let code = r#"
525
+ dict_tracker = dict_manager.get_tracker(ids.dict)
526
+ dict_tracker.data[ids.a]
527
+ "# ;
528
+
529
+ let py_result = py. run ( code, Some ( globals) , None ) ;
530
+ let key = PyMaybeRelocatable :: from ( PyRelocatable :: from ( ( 128 , 64 ) ) ) ;
531
+
532
+ assert_eq ! (
533
+ py_result. map_err( to_vm_error) ,
534
+ Err ( PyKeyError :: new_err( key. to_object( py) ) ) . map_err( to_vm_error) ,
535
+ ) ;
536
+
537
+ let code = r#"
538
+ dict_tracker = dict_manager.get_tracker(ids.dict)
539
+ dict_tracker.data[ids.a] = 5
540
+ "# ;
541
+
542
+ let py_result = py. run ( code, Some ( globals) , None ) ;
543
+ let key = PyMaybeRelocatable :: from ( PyRelocatable :: from ( ( 128 , 64 ) ) ) ;
544
+
545
+ assert_eq ! (
546
+ py_result. map_err( to_vm_error) ,
547
+ Err ( PyKeyError :: new_err( key. to_object( py) ) ) . map_err( to_vm_error) ,
548
+ ) ;
549
+ } ) ;
550
+ }
551
+
442
552
#[ test]
443
553
fn tracker_get_and_set_current_ptr ( ) {
444
554
Python :: with_gil ( |py| {
@@ -524,4 +634,105 @@ assert dict_tracker.current_ptr == ids.end_ptr
524
634
assert_eq ! ( py_result. map_err( to_vm_error) , Ok ( ( ) ) ) ;
525
635
} ) ;
526
636
}
637
+
638
+ #[ test]
639
+ fn manager_get_tracker_invalid_dict_ptr ( ) {
640
+ Python :: with_gil ( |py| {
641
+ let vm = PyVM :: new (
642
+ BigInt :: new ( Sign :: Plus , vec ! [ 1 , 0 , 0 , 0 , 0 , 0 , 17 , 134217728 ] ) ,
643
+ false ,
644
+ ) ;
645
+ for _ in 0 ..2 {
646
+ vm. vm . borrow_mut ( ) . add_memory_segment ( ) ;
647
+ }
648
+
649
+ let dict_manager = PyDictManager :: default ( ) ;
650
+
651
+ let segment_manager = PySegmentManager :: new ( & vm, PyMemory :: new ( & vm) ) ;
652
+
653
+ //Create references
654
+ let mut references = HashMap :: new ( ) ;
655
+ references. insert (
656
+ String :: from ( "dict" ) ,
657
+ HintReference {
658
+ register : Some ( Register :: FP ) ,
659
+ offset1 : 0 ,
660
+ offset2 : 0 ,
661
+ inner_dereference : false ,
662
+ ap_tracking_data : None ,
663
+ immediate : None ,
664
+ dereference : true ,
665
+ cairo_type : Some ( String :: from ( "DictAccess*" ) ) ,
666
+ } ,
667
+ ) ;
668
+ references. insert (
669
+ String :: from ( "no_dict" ) ,
670
+ HintReference {
671
+ register : Some ( Register :: FP ) ,
672
+ offset1 : 0 ,
673
+ offset2 : 0 ,
674
+ inner_dereference : false ,
675
+ ap_tracking_data : None ,
676
+ immediate : None ,
677
+ dereference : true ,
678
+ cairo_type : Some ( String :: from ( "DictAccess" ) ) ,
679
+ } ,
680
+ ) ;
681
+
682
+ let mut struct_types: HashMap < String , HashMap < String , Member > > = HashMap :: new ( ) ;
683
+ struct_types. insert ( String :: from ( "DictAccess" ) , HashMap :: new ( ) ) ;
684
+
685
+ let ids = PyIds :: new (
686
+ & vm,
687
+ & references,
688
+ & ApTracking :: default ( ) ,
689
+ & HashMap :: new ( ) ,
690
+ Rc :: new ( struct_types) ,
691
+ ) ;
692
+
693
+ let globals = PyDict :: new ( py) ;
694
+ globals
695
+ . set_item ( "dict_manager" , PyCell :: new ( py, dict_manager) . unwrap ( ) )
696
+ . unwrap ( ) ;
697
+ globals
698
+ . set_item ( "ids" , PyCell :: new ( py, ids) . unwrap ( ) )
699
+ . unwrap ( ) ;
700
+ globals
701
+ . set_item ( "segments" , PyCell :: new ( py, segment_manager) . unwrap ( ) )
702
+ . unwrap ( ) ;
703
+
704
+ let code = r#"
705
+ ids.dict = dict_manager.new_dict(segments, {})
706
+ dict_tracker = dict_manager.get_tracker(ids.no_dict)
707
+ "# ;
708
+
709
+ let py_result = py. run ( code, Some ( globals) , None ) ;
710
+
711
+ assert_eq ! (
712
+ py_result. map_err( to_vm_error) ,
713
+ Err ( to_vm_error( to_py_error(
714
+ VirtualMachineError :: NoDictTracker ( vm. vm. borrow( ) . get_fp( ) . segment_index) ,
715
+ ) ) ) ,
716
+ ) ;
717
+
718
+ let code = r#"
719
+ dict_tracker = dict_manager.get_tracker(ids.dict)
720
+ dict_tracker.current_ptr = dict_tracker.current_ptr + 3
721
+
722
+ dict_tracker = dict_manager.get_tracker(ids.dict)
723
+ "# ;
724
+
725
+ let py_result = py. run ( code, Some ( globals) , None ) ;
726
+
727
+ assert_eq ! (
728
+ py_result. map_err( to_vm_error) ,
729
+ Err ( to_vm_error( to_py_error(
730
+ VirtualMachineError :: MismatchedDictPtr (
731
+ Relocatable :: from( ( 2 , 3 ) ) ,
732
+ Relocatable :: from( ( 2 , 0 ) ) ,
733
+ ) ,
734
+ ) ) ) ,
735
+ ) ;
736
+ } ) ;
737
+ }
527
738
}
0 commit comments