Skip to content

Commit

Permalink
Merge pull request #35 from namnc/feat/multidimensional-assignments
Browse files Browse the repository at this point in the history
feat: multidimensional assignments
  • Loading branch information
namnc authored Mar 18, 2024
2 parents dcbad0d + 47d0a60 commit 530613e
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 158 deletions.
20 changes: 20 additions & 0 deletions src/assets/arrayAssignment.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pragma circom 2.0.0;

template ComponentA () {
signal input in[2][2];
signal output out;

out <== in[0][0] + in[0][1] + in[1][0] + in[1][1];
}

template ComponentB() {
signal input a_in[2][2];
signal output out;

component a = ComponentA();
a.in <== a_in;

out <== a.out;
}

component main = ComponentB();
2 changes: 1 addition & 1 deletion src/assets/circuit.circom
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,4 @@ template network() {
// out <== l2.out;
}

component main = network();
component main = network();
285 changes: 149 additions & 136 deletions src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
use crate::circuit::{AGateType, ArithmeticCircuit};
use crate::program::ProgramError;
use crate::runtime::{
generate_u32, increment_indices, u32_to_access, Context, DataAccess, DataType, Runtime, Signal,
SubAccess, RETURN_VAR,
generate_u32, increment_indices, u32_to_access, Context, DataAccess, DataType, NestedValue,
Runtime, Signal, SubAccess, RETURN_VAR,
};
use circom_circom_algebra::num_traits::ToPrimitive;
use circom_program_structure::ast::{
Expand Down Expand Up @@ -37,16 +37,17 @@ pub fn process_statement(
statement: &Statement,
) -> Result<(), ProgramError> {
match statement {
Statement::Block { stmts, .. } => process_statements(ac, runtime, program_archive, stmts),
Statement::InitializationBlock {
initializations, ..
} => {
for stmt in initializations {
process_statement(ac, runtime, program_archive, stmt)?;
}

Ok(())
}
} => process_statements(ac, runtime, program_archive, initializations),
Statement::Block { stmts, .. } => process_statements(ac, runtime, program_archive, stmts),
Statement::Substitution {
var,
access,
rhe,
op,
..
} => handle_substitution(ac, runtime, program_archive, var, access, rhe, op),
Statement::Declaration {
xtype,
name,
Expand Down Expand Up @@ -95,27 +96,6 @@ pub fn process_statement(

Ok(())
}
Statement::While { cond, stmt, .. } => {
runtime.push_context(true)?;
loop {
let access = process_expression(ac, runtime, program_archive, cond)?;
let result = runtime
.current_context()?
.get_variable_value(&access)?
.ok_or(ProgramError::EmptyDataItem)?;

if result == 0 {
break;
}

runtime.push_context(true)?;
process_statement(ac, runtime, program_archive, stmt)?;
runtime.pop_context(true)?;
}
runtime.pop_context(true)?;

Ok(())
}
Statement::IfThenElse {
cond,
if_case,
Expand Down Expand Up @@ -144,46 +124,24 @@ pub fn process_statement(
Ok(())
}
}
Statement::Substitution {
var,
access,
rhe,
op,
..
} => {
let lh_access = build_access(ac, runtime, program_archive, var, access)?;
let rh_access = process_expression(ac, runtime, program_archive, rhe)?;

let ctx = runtime.current_context()?;
match ctx.get_item_data_type(var)? {
DataType::Signal => {
// Connect the generated gate output to the given signal
let given_output_id = ctx.get_signal_id(&lh_access)?;
let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?;
Statement::While { cond, stmt, .. } => {
runtime.push_context(true)?;
loop {
let access = process_expression(ac, runtime, program_archive, cond)?;
let result = runtime
.current_context()?
.get_variable_value(&access)?
.ok_or(ProgramError::EmptyDataItem)?;

ac.add_connection(gate_output_id, given_output_id)?;
}
DataType::Variable => {
// Assign the evaluated right-hand side to the left-hand side
let value = ctx.get_variable_value(&rh_access)?;
ctx.set_variable(&lh_access, value)?;
if result == 0 {
break;
}
DataType::Component => match op {
AssignOp::AssignVar => {
// Component assignment
let signal_map = ctx.get_component_map(&rh_access)?;
ctx.set_component(&lh_access, signal_map)?;
}
AssignOp::AssignConstraintSignal => {
// Add connection
let component_signal = ctx.get_component_signal_id(&lh_access)?;
let assigned_signal = get_signal_for_access(ac, ctx, &rh_access)?;

ac.add_connection(assigned_signal, component_signal)?;
}
_ => return Err(ProgramError::OperationNotSupported),
},
runtime.push_context(true)?;
process_statement(ac, runtime, program_archive, stmt)?;
runtime.pop_context(true)?;
}
runtime.pop_context(true)?;

Ok(())
}
Expand All @@ -200,27 +158,91 @@ pub fn process_statement(

Ok(())
}
Statement::MultSubstitution { meta, lhe, op, rhe } => {
println!("Statement not implemented: MultSubstitution");
Ok(())
}
Statement::UnderscoreSubstitution { meta, op, rhe } => {
println!("Statement not implemented: UnderscoreSubstitution");
Ok(())
}
Statement::ConstraintEquality { meta, lhe, rhe } => {
println!("Statement not implemented: ConstraintEquality");
Ok(())
}
Statement::LogCall { meta, args } => {
println!("Statement not implemented: LogCall");
Ok(())
}
Statement::Assert { meta, arg } => {
println!("Statement not implemented: Assert");
Ok(())
_ => Err(ProgramError::StatementNotImplemented),
}
}

/// Handles a substitution statement
fn handle_substitution(
ac: &mut ArithmeticCircuit,
runtime: &mut Runtime,
program_archive: &ProgramArchive,
var: &str,
access: &[Access],
rhe: &Expression,
op: &AssignOp,
) -> Result<(), ProgramError> {
let lh_access = build_access(ac, runtime, program_archive, var, access)?;
let rh_access = process_expression(ac, runtime, program_archive, rhe)?;

let ctx = runtime.current_context()?;
match ctx.get_item_data_type(var)? {
DataType::Variable => {
// Assign the evaluated right-hand side to the left-hand side
let value = ctx.get_variable_value(&rh_access)?;
ctx.set_variable(&lh_access, value)?;
}
DataType::Component => match op {
AssignOp::AssignVar => {
// Component instantiation
let signal_map = ctx.get_component_map(&rh_access)?;
ctx.set_component(&lh_access, signal_map)?;
}
AssignOp::AssignConstraintSignal => {
// Component signal assignment
match ctx.get_component_signal_content(&lh_access)? {
NestedValue::Array(signal) => {
let assigned_signal_array =
match get_signal_content_for_access(ctx, &rh_access)? {
NestedValue::Array(array) => array,
_ => return Err(ProgramError::InvalidDataType),
};

connect_signal_arrays(ac, &signal, &assigned_signal_array)?;
}
NestedValue::Value(_) => {
let component_signal = ctx.get_component_signal_id(&lh_access)?;
let assigned_signal = get_signal_for_access(ac, ctx, &rh_access)?;

ac.add_connection(assigned_signal, component_signal)?;
}
}
}
_ => return Err(ProgramError::OperationNotSupported),
},
DataType::Signal => {
match rhe {
Expression::InfixOp { .. } => {
// Construct the corresponding circuit gate for the given operation
let given_output_id = ctx.get_signal_id(&lh_access)?;
let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?;

// Connect the generated gate output to the actual signal
ac.add_connection(gate_output_id, given_output_id)?;
}
Expression::Variable { .. } => match ctx.get_signal_content(&lh_access)? {
// This corresponds to
NestedValue::Array(signal) => {
let assigned_signal_array =
match get_signal_content_for_access(ctx, &rh_access)? {
NestedValue::Array(array) => array,
_ => return Err(ProgramError::InvalidDataType),
};

connect_signal_arrays(ac, &signal, &assigned_signal_array)?;
}
NestedValue::Value(signal_id) => {
let gate_output_id = get_signal_for_access(ac, ctx, &rh_access)?;

ac.add_connection(gate_output_id, signal_id)?;
}
},
_ => {}
}
}
}

Ok(())
}

/// Processes an expression and returns an access to the result.
Expand Down Expand Up @@ -250,54 +272,7 @@ pub fn process_expression(
Expression::Variable { name, access, .. } => {
build_access(ac, runtime, program_archive, name, access)
}
Expression::PrefixOp {
meta,
prefix_op,
rhe,
} => {
println!("Expression not implemented:PrefixOp");
Ok(DataAccess::new("", vec![]))
}
Expression::InlineSwitchOp {
meta,
cond,
if_true,
if_false,
} => {
println!("Expression not implemented:InlineSwitchOp");
Ok(DataAccess::new("", vec![]))
}
Expression::ParallelOp { meta, rhe } => {
println!("Expression not implemented:ParallelOp");
Ok(DataAccess::new("", vec![]))
}
Expression::AnonymousComp {
meta,
id,
is_parallel,
params,
signals,
names,
} => {
println!("Expression not implemented:AnonymousComp");
Ok(DataAccess::new("", vec![]))
}
Expression::ArrayInLine { meta, values } => {
println!("Expression not implemented:ArrayInLine");
Ok(DataAccess::new("", vec![]))
}
Expression::Tuple { meta, values } => {
println!("Expression not implemented:Tuple");
Ok(DataAccess::new("", vec![]))
}
Expression::UniformArray {
meta,
value,
dimension,
} => {
println!("Expression not implemented: UniformArray");
Ok(DataAccess::new("", vec![]))
}
_ => Err(ProgramError::ExpressionNotImplemented),
}
}

Expand Down Expand Up @@ -470,8 +445,46 @@ fn get_signal_for_access(
}
}

/// Returns the content of a signal for a given access
fn get_signal_content_for_access(
ctx: &Context,
access: &DataAccess,
) -> Result<NestedValue<u32>, ProgramError> {
match ctx.get_item_data_type(&access.get_name())? {
DataType::Signal => Ok(ctx.get_signal_content(access)?),
DataType::Component => Ok(ctx.get_component_signal_content(access)?),
_ => Err(ProgramError::InvalidDataType),
}
}

/// Connects two composed signals
fn connect_signal_arrays(
ac: &mut ArithmeticCircuit,
a: &Vec<NestedValue<u32>>,
b: &Vec<NestedValue<u32>>,
) -> Result<(), ProgramError> {
// Verify that the arrays have the same length
if a.len() != b.len() {
return Err(ProgramError::InvalidDataType);
}

for (a, b) in a.iter().zip(b.iter()) {
match (a, b) {
(NestedValue::Value(a), NestedValue::Value(b)) => {
ac.add_connection(*a, *b)?;
}
(NestedValue::Array(a), NestedValue::Array(b)) => {
connect_signal_arrays(ac, a, b)?;
}
_ => return Err(ProgramError::InvalidDataType),
}
}

Ok(())
}

/// Builds a DataAccess from an Access array
pub fn build_access(
fn build_access(
ac: &mut ArithmeticCircuit,
runtime: &mut Runtime,
program_archive: &ProgramArchive,
Expand Down Expand Up @@ -500,7 +513,7 @@ pub fn build_access(
}

/// Executes an operation on two u32 values, performing the specified arithmetic or logical computation.
pub fn execute_op(lhs: u32, rhs: u32, op: &ExpressionInfixOpcode) -> Result<u32, ProgramError> {
fn execute_op(lhs: u32, rhs: u32, op: &ExpressionInfixOpcode) -> Result<u32, ProgramError> {
let res = match op {
ExpressionInfixOpcode::Mul => lhs * rhs,
ExpressionInfixOpcode::Div => {
Expand Down
4 changes: 4 additions & 0 deletions src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub enum ProgramError {
CircuitError(CircuitError),
#[error("Empty data item")]
EmptyDataItem,
#[error("Expression not implemented")]
ExpressionNotImplemented,
#[error("Input initialization error")]
InputInitializationError,
#[error("Invalid data type")]
Expand All @@ -58,6 +60,8 @@ pub enum ProgramError {
ParsingError,
#[error("Runtime error: {0}")]
RuntimeError(RuntimeError),
#[error("Statement not implemented")]
StatementNotImplemented,
#[error("Undefined function or template")]
UndefinedFunctionOrTemplate,
}
Loading

0 comments on commit 530613e

Please sign in to comment.