Skip to content

Commit 6128817

Browse files
authored
Add tests for the failing cases of PyDictManager (#151)
* Add test for invalid key on dict read and write * Add test for get_tracker with invalid dict pointer
1 parent f10e775 commit 6128817

File tree

1 file changed

+211
-0
lines changed

1 file changed

+211
-0
lines changed

src/dict_manager.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,12 @@ impl PyDictTracker {
164164
mod tests {
165165
use crate::{ids::PyIds, memory::PyMemory, utils::to_vm_error, vm_core::PyVM};
166166
use cairo_rs::{
167+
bigint,
167168
hint_processor::hint_processor_definition::HintReference,
168169
serde::deserialize_program::{ApTracking, Member},
169170
types::relocatable::Relocatable,
170171
types::{instruction::Register, relocatable::MaybeRelocatable},
172+
vm::errors::vm_errors::VirtualMachineError,
171173
};
172174
use num_bigint::{BigInt, Sign};
173175
use pyo3::{types::PyDict, PyCell};
@@ -439,6 +441,114 @@ assert dict_tracker.data[1] == 22
439441
});
440442
}
441443

444+
#[test]
445+
fn tracker_read_and_write_invalid_key() {
446+
Python::with_gil(|py| {
447+
let vm = PyVM::new(
448+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
449+
false,
450+
);
451+
for _ in 0..2 {
452+
vm.vm.borrow_mut().add_memory_segment();
453+
}
454+
455+
let dict_manager = PyDictManager::default();
456+
457+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
458+
459+
//Create references
460+
let mut references = HashMap::new();
461+
references.insert(
462+
String::from("dict"),
463+
HintReference {
464+
register: Some(Register::FP),
465+
offset1: 0,
466+
offset2: 0,
467+
inner_dereference: false,
468+
ap_tracking_data: None,
469+
immediate: None,
470+
dereference: true,
471+
cairo_type: Some(String::from("DictAccess*")),
472+
},
473+
);
474+
// Create ids.a
475+
references.insert(String::from("a"), HintReference::new_simple(1));
476+
477+
//Insert ids.a into memory
478+
vm.vm
479+
.borrow_mut()
480+
.insert_value(
481+
&Relocatable::from((1, 1)),
482+
&MaybeRelocatable::from((128, 64)),
483+
)
484+
.unwrap();
485+
486+
let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
487+
struct_types.insert(String::from("DictAccess"), HashMap::new());
488+
489+
let ids = PyIds::new(
490+
&vm,
491+
&references,
492+
&ApTracking::default(),
493+
&HashMap::new(),
494+
Rc::new(struct_types),
495+
);
496+
497+
let globals = PyDict::new(py);
498+
globals
499+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
500+
.unwrap();
501+
globals
502+
.set_item("ids", PyCell::new(py, ids).unwrap())
503+
.unwrap();
504+
globals
505+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
506+
.unwrap();
507+
508+
let code = r#"
509+
initial_dict = { 1: 2, 4: 8, 16: 32 }
510+
ids.dict = dict_manager.new_dict(segments, initial_dict)
511+
dict_tracker = dict_manager.get_tracker(ids.dict)
512+
dict_tracker.data[3]
513+
"#;
514+
515+
let py_result = py.run(code, Some(globals), None);
516+
517+
assert_eq!(
518+
py_result.map_err(to_vm_error),
519+
Err(to_vm_error(to_py_error(
520+
VirtualMachineError::NoValueForKey(bigint!(3))
521+
))),
522+
);
523+
524+
let code = r#"
525+
dict_tracker = dict_manager.get_tracker(ids.dict)
526+
dict_tracker.data[ids.a]
527+
"#;
528+
529+
let py_result = py.run(code, Some(globals), None);
530+
let key = PyMaybeRelocatable::from(PyRelocatable::from((128, 64)));
531+
532+
assert_eq!(
533+
py_result.map_err(to_vm_error),
534+
Err(PyKeyError::new_err(key.to_object(py))).map_err(to_vm_error),
535+
);
536+
537+
let code = r#"
538+
dict_tracker = dict_manager.get_tracker(ids.dict)
539+
dict_tracker.data[ids.a] = 5
540+
"#;
541+
542+
let py_result = py.run(code, Some(globals), None);
543+
let key = PyMaybeRelocatable::from(PyRelocatable::from((128, 64)));
544+
545+
assert_eq!(
546+
py_result.map_err(to_vm_error),
547+
Err(PyKeyError::new_err(key.to_object(py))).map_err(to_vm_error),
548+
);
549+
});
550+
}
551+
442552
#[test]
443553
fn tracker_get_and_set_current_ptr() {
444554
Python::with_gil(|py| {
@@ -524,4 +634,105 @@ assert dict_tracker.current_ptr == ids.end_ptr
524634
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
525635
});
526636
}
637+
638+
#[test]
639+
fn manager_get_tracker_invalid_dict_ptr() {
640+
Python::with_gil(|py| {
641+
let vm = PyVM::new(
642+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
643+
false,
644+
);
645+
for _ in 0..2 {
646+
vm.vm.borrow_mut().add_memory_segment();
647+
}
648+
649+
let dict_manager = PyDictManager::default();
650+
651+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
652+
653+
//Create references
654+
let mut references = HashMap::new();
655+
references.insert(
656+
String::from("dict"),
657+
HintReference {
658+
register: Some(Register::FP),
659+
offset1: 0,
660+
offset2: 0,
661+
inner_dereference: false,
662+
ap_tracking_data: None,
663+
immediate: None,
664+
dereference: true,
665+
cairo_type: Some(String::from("DictAccess*")),
666+
},
667+
);
668+
references.insert(
669+
String::from("no_dict"),
670+
HintReference {
671+
register: Some(Register::FP),
672+
offset1: 0,
673+
offset2: 0,
674+
inner_dereference: false,
675+
ap_tracking_data: None,
676+
immediate: None,
677+
dereference: true,
678+
cairo_type: Some(String::from("DictAccess")),
679+
},
680+
);
681+
682+
let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
683+
struct_types.insert(String::from("DictAccess"), HashMap::new());
684+
685+
let ids = PyIds::new(
686+
&vm,
687+
&references,
688+
&ApTracking::default(),
689+
&HashMap::new(),
690+
Rc::new(struct_types),
691+
);
692+
693+
let globals = PyDict::new(py);
694+
globals
695+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
696+
.unwrap();
697+
globals
698+
.set_item("ids", PyCell::new(py, ids).unwrap())
699+
.unwrap();
700+
globals
701+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
702+
.unwrap();
703+
704+
let code = r#"
705+
ids.dict = dict_manager.new_dict(segments, {})
706+
dict_tracker = dict_manager.get_tracker(ids.no_dict)
707+
"#;
708+
709+
let py_result = py.run(code, Some(globals), None);
710+
711+
assert_eq!(
712+
py_result.map_err(to_vm_error),
713+
Err(to_vm_error(to_py_error(
714+
VirtualMachineError::NoDictTracker(vm.vm.borrow().get_fp().segment_index),
715+
))),
716+
);
717+
718+
let code = r#"
719+
dict_tracker = dict_manager.get_tracker(ids.dict)
720+
dict_tracker.current_ptr = dict_tracker.current_ptr + 3
721+
722+
dict_tracker = dict_manager.get_tracker(ids.dict)
723+
"#;
724+
725+
let py_result = py.run(code, Some(globals), None);
726+
727+
assert_eq!(
728+
py_result.map_err(to_vm_error),
729+
Err(to_vm_error(to_py_error(
730+
VirtualMachineError::MismatchedDictPtr(
731+
Relocatable::from((2, 3)),
732+
Relocatable::from((2, 0)),
733+
),
734+
))),
735+
);
736+
});
737+
}
527738
}

0 commit comments

Comments
 (0)