
% assume existence of
%        def(threadcnt, ThreadCnt)
%        def(varcnt, VarCnt)
%        def(qsize, Qsize)  -- needed for checking whether Q is full
% given in the file containing the input program.

% the input program provides the following:
%   instr(PC, Instr)   -- the program for all the threads
%----------------------------------------------------------------

:- import length/2, append/3, member/2 from basics.

:- table reach/4.

% reach(Thds0, Mem0, Thds, Mem)
%   Holds when the global state (Thds, Mem) is obtained from
%   (Thds0, Mem0) by one or more trans/4 steps.

% A single step.
reach(Thds0, Mem0, Thds, Mem) :-
    trans(Thds0, Mem0, Thds, Mem).

% One step to an intermediate state, then continue from there.
reach(Thds0, Mem0, Thds, Mem) :-
    trans(Thds0, Mem0, ThdsMid, MemMid),
    reach(ThdsMid, MemMid, Thds, Mem).

%-------------------------------------------------

% trans(Thds1, MM1, Thds2, MM2)
%   One global step: exactly one thread of the list Thds1 performs a
%   trans_one/4 move; all other threads are carried over unchanged.

% The thread at the head of the list moves.
trans([Cur|Others], Mem0, [Next|Others], Mem1) :-
    trans_one(Cur, Mem0, Next, Mem1).

% The head thread stays put; a thread further down the list moves.
trans([Idle|Tail0], Mem0, [Idle|Tail1], Mem1) :-
    trans(Tail0, Mem0, Tail1, Mem1).

% trans_one(T1, MM1, T2, MM2)
%   A step of one thread: either a program action (pgm_trans/4) or a
%   platform action (jvm_trans/5) on some shared variable.

trans_one(Thd0, Mem0, Thd1, Mem1) :-
    pgm_trans(Thd0, Mem0, Thd1, Mem1).

trans_one(Thd0, Mem0, Thd1, Mem1) :-
    % Disabled guard: not(pgm_trans(Thd0, Mem0, _, _)),
    pick_sharedVar(SharedVar),
    jvm_trans(SharedVar, Thd0, Mem0, Thd1, Mem1).


%----------------------------------------------- 

% jvm_trans(Var, Th1, MM1, Th2, MM2)
%   A platform action on shared variable Var.  The alternatives are
%   tried in the order store, load, read, write.
jvm_trans(Var, Thd0, Mem0, Thd1, Mem1) :-
    (   store_trans(Var, Thd0, Mem0, Thd1, Mem1)
    ;   load_trans(Var, Thd0, Mem0, Thd1, Mem1)
    ;   read_trans(Var, Thd0, Mem0, Thd1, Mem1)
    ;   write_trans(Var, Thd0, Mem0, Thd1, Mem1)
    ).

%------------------------------------------------------


% pgm_trans(Th1, MM1, Th2, MM2)
%   A program action: look up the instruction at the thread's PC and
%   execute it.  Fails when instr/2 has no entry for that PC.
pgm_trans(Thd0, Mem0, Thd1, Mem1) :-
    Thd0 = (PC, _),              % PC is the first component of a thread
    instr(PC, Instr),
    instr_exec(Instr, Thd0, Mem0, Thd1, Mem1).

% instr_fetch(Th, Instr)
%   Instr is the instruction stored by instr/2 at the PC of thread Th
%   (the PC is the first component of the thread tuple).
instr_fetch((PC, _), Instr) :-
    instr(PC, Instr).

% If PC indicates that the thread has terminated,
% i.e.   PC = (Thread_id, end)
% Then instr/2 and hence instr_fetch/2 fails.

% instr_exec(Instr, Th1, MM1, Th2, MM2)
%   Executes instruction Instr for thread state Th1 and memory state
%   MM1, yielding Th2 and MM2.  Each instruction form is delegated to
%   its own transition predicate.

% use(Var)
instr_exec(use(Var), Thd0, Mem0, Thd1, Mem1) :-
    use_trans(Var, Thd0, Mem0, Thd1, Mem1).

% assign(Var, Expr): Expr is evaluated in the state before the step.
instr_exec(assign(Var, Expr), Thd0, Mem0, Thd1, Mem1) :-
    eval_expr(Expr, Thd0, Mem0, Val),
    assign_trans(Var, Val, Thd0, Mem0, Thd1, Mem1).

instr_exec(lock, Thd0, Mem0, Thd1, Mem1) :-
    lock_trans(Thd0, Mem0, Thd1, Mem1).

instr_exec(unlock, Thd0, Mem0, Thd1, Mem1) :-
    unlock_trans(Thd0, Mem0, Thd1, Mem1).

% "use and assert" actions --
%       a use, followed by a check on the resulting state.
%       Only eq/neq checks are supported for now; any other Op fails.
instr_exec(assert(Op, use(Var), Expr), Thd0, Mem0, Thd1, Mem1) :-
    use_trans(Var, Thd0, Mem0, Thd1, Mem1),
    eval_expr(reg(Var), Thd1, Mem1, Seen),
    eval_expr(Expr, Thd1, Mem1, Wanted),
    (   Op == eq
    ->  Seen = Wanted
    ;   Op == neq,
        Seen \= Wanted
    ).

% flush: extra instruction used to observe the effect of the Fence
% instructions inserted by the JMM.
instr_exec(flush, Thd0, Mem0, Thd1, Mem1) :-
    flush_trans(Thd0, Mem0, Thd1, Mem1).

%--------------------------------------------------------------

% local state of any thread is defined as:
%     (PC, Cache, Registers, Read-queue, Write-queue)
% "use" action reads from Cache and writes to Registers
% "assign" action writes to Cache

% PROGRAM  ACTIONS

% flush_trans(Th1, MM, Th2, MM)
%   The flush instruction only advances the PC.  It is enabled when no
%   cache line is dirty and every write queue of the thread is empty.
%   Main memory is left untouched.
flush_trans((PC, Cache, Regs, RdQs, WrQs), Mem, Thd1, Mem) :-
    next_instr(PC, NextPC),
    Thd1 = (NextPC, Cache, Regs, RdQs, WrQs),
    not(member((_, true, _), Cache)),    % GUARD: no dirty cache line
    all_same([], WrQs).                  % GUARD: all write queues empty


% use_trans(Var, Th1, MM, Th2, MM)
%   "use" copies the cached value of Var into register Var and
%   advances the PC.  Enabled only when the cache line is not stale.
%   Main memory is left untouched.
use_trans(Var, (PC, Cache, Regs0, RdQs, WrQs), Mem, Thd1, Mem) :-
    read_array(Cache, Var, (Cached, _Dirty, Stale)),
    Stale == false,                      % GUARD: cache line is fresh
    write_array(Regs0, Var, Cached, Regs1),
    next_instr(PC, NextPC),
    Thd1 = (NextPC, Cache, Regs1, RdQs, WrQs).


% assign_trans(Var, Val, Th1, MM, Th2, MM)
%   "assign" stores Val in the cache line of Var, marking it dirty and
%   not stale, and advances the PC.  Enabled only when the read queue
%   of Var is empty.  Main memory is left untouched.
assign_trans(Var, Val, (PC, Cache0, Regs, RdQs, WrQs), Mem, Thd1, Mem) :-
    read_array(RdQs, Var, []),           % GUARD: no pending reads of Var
    next_instr(PC, NextPC),
    write_array(Cache0, Var, (Val, true, false), Cache1),
    Thd1 = (NextPC, Cache1, Regs, RdQs, WrQs).


% lock_trans(Th1, MM1, Th2, MM2)
%   "lock" pushes the thread id onto the lock list, marks every cache
%   line stale and advances the PC.
%   Enabled when the lock list is empty or holds only this thread's id,
%   all read queues are empty, and no cache line is dirty.
lock_trans((PC, Cache0, Regs, RdQs, WrQs), (Memory, Holders), Thd1, Mem1) :-
    PC = (Tid, _LocalPC),
    all_same(Tid, Holders),              % GUARD: lock free or already ours
    all_same([], RdQs),                  % GUARD: no pending reads
    all_same((_, false, _), Cache0),     % GUARD: no dirty cache line
    Mem1 = (Memory, [Tid|Holders]),
    set_stale(Cache0, Cache1),
    next_instr(PC, NextPC),
    Thd1 = (NextPC, Cache1, Regs, RdQs, WrQs).

% unlock_trans(Th1, MM1, Th2, MM2)
%   "unlock" removes one entry from the lock list and advances the PC.
%   Enabled when this thread's id occurs in the lock list, all write
%   queues of the thread are empty, and no cache line is dirty.
%
%   The membership test is a pure guard, so it is committed to its
%   first solution.  A bare member/2 call leaves one choice point per
%   occurrence of the thread id; with a re-entrantly held lock (e.g.
%   [T,T]) that would produce the same successor state several times.
unlock_trans(Th1, MM1, Th2, MM2) :-
	Th1 = (PC, Cache, Regs, Rdq, Wrq),
	MM1 = (Memory, Lock1),
	PC = (Thread_id, _Local_PC),
	( member(Thread_id, Lock1) -> true ; fail ),  % GUARD: we hold the lock
	all_same([], Wrq),                  % GUARD: no pending writes
	all_same((_,false,_), Cache),       % GUARD: no dirty cache line
	Lock1 = [_|Lock2],                  % release one level of the lock
	MM2 = (Memory, Lock2),
	next_instr(PC, NPC),
	Th2 = (NPC, Cache, Regs, Rdq, Wrq).


% Platform ACTIONS

% load_trans(Var, Th1, MM, Th2, MM)
%   Platform "load": moves the oldest value in the read queue of Var
%   into the cache line of Var and clears its stale bit.  Enabled when
%   the line is stale and not dirty and the read queue is non-empty.
%   PC and main memory are left untouched.
load_trans(Var, (PC, Cache0, Regs, RdQs0, WrQs), Mem, Thd1, Mem) :-
    read_array(RdQs0, Var, Pending0),
    read_array(Cache0, Var, (_Old, Dirty, true)),
    Dirty == false,                      % GUARD: line is clean
    Pending0 \== [],                     % GUARD: something to load
    dequeue(Pending0, Loaded, Pending1),
    write_array(RdQs0, Var, Pending1, RdQs1),
    write_array(Cache0, Var, (Loaded, Dirty, false), Cache1),
    Thd1 = (PC, Cache1, Regs, RdQs1, WrQs).



% store_trans(Var, Th1, MM, Th2, MM)
%   Platform "store": appends the cached value of Var to the write
%   queue of Var and clears the dirty bit.  Enabled when the line is
%   dirty, the read queue of Var is empty and the write queue of Var
%   is not full.  PC and main memory are left untouched.
store_trans(Var, (PC, Cache0, Regs, RdQs, WrQs0), Mem, Thd1, Mem) :-
    read_array(Cache0, Var, (Val, Dirty, Stale)),
    read_array(WrQs0, Var, Outgoing0),
    read_array(RdQs, Var, []),           % GUARD: no pending reads of Var
    Dirty == true,                       % GUARD: line has unsaved data
    not(full(Outgoing0)),                % GUARD: room in the write queue
    enqueue(Outgoing0, Val, Outgoing1),
    write_array(WrQs0, Var, Outgoing1, WrQs1),
    write_array(Cache0, Var, (Val, false, Stale), Cache1),
    Thd1 = (PC, Cache1, Regs, RdQs, WrQs1).


% read_trans(Var, Th1, MM, Th2, MM)
%   Platform "read": appends the main-memory value of Var to the read
%   queue of Var.  Enabled when memory_traffic_check/4 accepts the
%   read, the write queue of Var is empty, the read queue of Var is
%   not full, and the cache line of Var is stale and not dirty.
%   PC, cache and main memory are left untouched.
read_trans(Var, Thd0, Mem, Thd1, Mem) :-
    Thd0 = (PC, Cache, Regs, RdQs0, WrQs),
    Mem = (Memory, _Lock),
    read_array(RdQs0, Var, Incoming0),
    read_array(Memory, Var, MemVal),
    memory_traffic_check(Thd0, Var, Mem, MemVal),
    read_array(WrQs, Var, []),                 % GUARD: no pending writes
    not(full(Incoming0)),                      % GUARD: room in the read queue
    read_array(Cache, Var, (_Old, false, true)),  % GUARD: clean, stale line
    % (disabled) write_array(Cache1, Var, (X,false,toload), Cache2),
    enqueue(Incoming0, MemVal, Incoming1),
    write_array(RdQs0, Var, Incoming1, RdQs1),
    Thd1 = (PC, Cache, Regs, RdQs1, WrQs).



% write_trans(Var, Th1, MM1, Th2, MM2)
%   Platform "write": removes the oldest value from the write queue of
%   Var and stores it in main memory at Var.  Enabled when that write
%   queue is non-empty.  The lock list, PC, cache and registers are
%   left untouched.
write_trans(Var, (PC, Cache, Regs, RdQs, WrQs0), (Memory0, Lock), Thd1, Mem1) :-
    read_array(WrQs0, Var, Outgoing0),
    Outgoing0 \== [],                    % GUARD: something to write back
    dequeue(Outgoing0, Oldest, Outgoing1),
    write_array(Memory0, Var, Oldest, Memory1),
    write_array(WrQs0, Var, Outgoing1, WrQs1),
    Mem1 = (Memory1, Lock),
    Thd1 = (PC, Cache, Regs, RdQs, WrQs1).



%*****************************************************************

% read_array(List, Index, Val)
%   Val is the element of List at the zero-based position Index.
%   Fails when Index is beyond the end of List.
read_array([Elem|_], 0, Elem).
read_array([_|Tail], Index, Elem) :-
    Index > 0,
    Prev is Index - 1,
    read_array(Tail, Prev, Elem).

% write_array(List0, Index, New, List)
%   List is List0 with the element at zero-based position Index
%   replaced by New.  Fails when Index is beyond the end of List0.
write_array([_|Tail], 0, New, [New|Tail]).
write_array([Keep|Tail0], Index, New, [Keep|Tail]) :-
    Index > 0,
    Prev is Index - 1,
    write_array(Tail0, Prev, New, Tail).

% all_same(Pattern, List)
%   Every element of List unifies with Pattern.  Pattern is copied
%   afresh for each element, so variables in Pattern act as wildcards
%   that may match a different value in each element.
all_same(_Pattern, []).
all_same(Pattern, [Item|Items]) :-
    copy_term(Pattern, Fresh),
    Item = Fresh,
    all_same(Pattern, Items).

% special routine : to set all stale bits

% set_stale(Cache0, Cache)
%   Cache is Cache0 with the stale bit (third component) of every
%   line set to true; value and dirty bit are kept.
set_stale([], []).
set_stale([(Val, Dirty, _)|Lines0], [(Val, Dirty, true)|Lines]) :-
    set_stale(Lines0, Lines).

% pick_sharedVar(Var)
%   Enumerates, on backtracking, the shared-variable indices
%   0 .. VarCnt-1, where VarCnt comes from def(varcnt, VarCnt).
pick_sharedVar(Var) :-
    def(varcnt, Count),
    gen_var(0, Count, Var).

% gen_var(Lo, Hi, X)
%   X ranges over the integers Lo =< X < Hi, in ascending order on
%   backtracking.
gen_var(Lo, Hi, Lo) :-
    Lo < Hi.
gen_var(Lo, Hi, X) :-
    Next is Lo + 1,
    Next < Hi,
    gen_var(Next, Hi, X).

%------------------------------------------------------

% eval_expr(Expr, Th, MM, Val)
%   Val is the value of Expr, evaluated against the registers of
%   thread Th (third component of the thread tuple) and memory MM.
eval_expr(Expr, (_PC, _Cache, Regs, _RdQs, _WrQs), Mem, Val) :-
    eval1(Expr, Regs, Mem, Val).

% For now, operands are registers, immediate values, or main-memory cells
% (direct address, or base + offset); only integer values are considered.

% eval1(Expr, Regs, MM, Val)
%   Val is the value of operand Expr given register file Regs and
%   memory state MM = (Mem, Lock).

% reg(N): contents of register N.
eval1(reg(Index), Regs, _Mem, Val) :-
    read_array(Regs, Index, Val).

% imm(Val): immediate value.
eval1(imm(Val), _Regs, _Mem, Val).

% mem(Addr): main-memory cell at a direct address.
eval1(mem(Addr), _Regs, (Cells, _Lock), Val) :-
    read_array(Cells, Addr, Val).

% mem(Base, Offset): array access; Offset is itself an operand
% (direct or indirect) and is evaluated first.
eval1(mem(Base, Offset), Regs, Mem, Val) :-
    eval1(Offset, Regs, Mem, Delta),
    Addr is Base + Delta,
    Mem = (Cells, _Lock),
    read_array(Cells, Addr, Val).

%-------------------------------------------
% Queue operations -- trivial with Prolog lists

% enqueue(Q0, Item, Q): Q is queue Q0 with Item added at the back.
enqueue(Queue0, Item, Queue) :-
    append(Queue0, [Item], Queue).

% dequeue(Q0, Item, Q): Item is the front of queue Q0 and Q is the
% rest.  Fails on an empty queue.
dequeue([Item|Queue], Item, Queue).

% full(Q): queue Q holds at least as many items as the limit given
% by def(qsize, Max).
full(Queue) :-
    def(qsize, Limit),
    length(Queue, Len),
    Len >= Limit.

%---------------------------------------------------

 % Check to reduce memory traffic
 %  checks whether the next pgm action is a use

% memory_traffic_check(Th, Var, MM, MVal)
%   Restricts memory reads of Var to the cases where the thread's next
%   program instruction consumes Var:
%     - a plain use(Var); or
%     - an assert on use(Var) that the memory value MVal would satisfy.

% Next instruction is use(Var).
memory_traffic_check(Thd, Var, _Mem, _MemVal) :-
    instr_fetch(Thd, use(Var)).

% Next instruction asserts equality and MemVal matches the expected value.
memory_traffic_check(Thd, Var, Mem, MemVal) :-
    instr_fetch(Thd, assert(eq, use(Var), Expr)),
    eval_expr(Expr, Thd, Mem, Expected),
    MemVal = Expected.

% Next instruction asserts inequality and MemVal differs from the value.
memory_traffic_check(Thd, Var, Mem, MemVal) :-
    instr_fetch(Thd, assert(neq, use(Var), Expr)),
    eval_expr(Expr, Thd, Mem, Excluded),
    MemVal \= Excluded.

%-------------------------------------------

% init_state(Threads, Mem): the initial global state.  Memory comes
% from init_mem/1, threads from init_threads/1.
init_state(Threads, Mem) :-
    init_mem(Mem),
    init_threads(Threads).

% init_threads(Threads)
%   Threads is the list of initial thread states, sized by
%   def(threadcnt, _) and def(varcnt, _).
init_threads(Threads) :-
    def(threadcnt, NumThreads),
    def(varcnt, NumVars),
    init_threads(NumThreads, NumVars, Threads).

% init_threads(Count, VarCnt, Threads)
%   Threads holds the initial states of threads 0 .. Count-1, in
%   ascending order of thread id.
init_threads(0, _VarCnt, []).

init_threads(Count, VarCnt, Threads) :-
    Count > 0,
    Tid is Count - 1,
    init_one_thread(Tid, VarCnt, Last),
    init_threads(Tid, VarCnt, Earlier),
    append(Earlier, [Last], Threads).   % highest id goes at the end


% init_one_thread(ThreadId, VC, Thd)
%   Thd is the initial state of thread ThreadId for VC shared
%   variables: PC at local address 0, an initial cache and register
%   file, and empty read and write queues.
init_one_thread(ThreadId, VC, Thd) :-
    init_cache(VC, Cache),
    init_array(VC, Regs),
    gen_empty(VC, EmptyQs),
    Thd = ((ThreadId, 0), Cache, Regs, EmptyQs, EmptyQs).

% init_cache(N, Cache)
%   Cache is a list of N lines, each (0, false, true): value 0,
%   not dirty, stale.
init_cache(0, []).
init_cache(Count, [(0, false, true)|Lines]) :-
    Count > 0,
    Left is Count - 1,
    init_cache(Left, Lines).

% init_array(N, Array): Array is a list of N zeros.
init_array(0, []).
init_array(Count, [0|Zeros]) :-
    Count > 0,
    Left is Count - 1,
    init_array(Left, Zeros).

% gen_empty(N, Queues): Queues is a list of N empty lists.
gen_empty(0, []).
gen_empty(Count, [[]|Empties]) :-
    Count > 0,
    Left is Count - 1,
    gen_empty(Left, Empties).

% init_mem/1 will be given in the program file

%----------------------------------------------------------

% final_state(Thds, MM)
%   Describes an acceptable final global state: Thds matches the
%   final-thread pattern for def(threadcnt, _) threads and MM
%   satisfies final_mem/1.
final_state(Thds, MM) :-
    def(threadcnt, NumThreads),
    final_threads(NumThreads, Thds),
    final_mem(MM).


% final_threads(N, Thds)
%   Thds is the list of final-state patterns for threads 0 .. N-1,
%   in ascending order of thread id.
final_threads(0, []).
final_threads(Count, Thds) :-
    Count > 0,
    Tid is Count - 1,
    final_thd_one(Tid, Last),
    final_threads(Tid, Earlier),
    append(Earlier, [Last], Thds).

% final_thd_one(N, T)
%   T is the final-state pattern of thread N: PC is (N, end), both
%   queue arrays are empty, and the cache satisfies final_cache/2.
%   The registers are left unconstrained (the final_regs/2 check is
%   disabled).
final_thd_one(N, ((N, end), Cache, _Regs, RdQs, WrQs)) :-
    def(varcnt, NumVars),
    gen_empty(NumVars, RdQs),
    gen_empty(NumVars, WrQs),
    final_cache(NumVars, Cache).
    % (disabled) final_regs(NumVars, Regs)

% final_cache(N, Cache)
%   Cache is a list of N lines, none of which is dirty.  Value and
%   stale bit are unconstrained: stale lines are allowed at the end,
%   since a shared variable may never have been read or written by a
%   given thread.  (The basetype/1 check on the value and the check
%   Stale = false are disabled.)
final_cache(0, []).
final_cache(Count, [Line|Lines]) :-
    Count > 0,
    Line = (_Val, false, _Stale),        % only the dirty bit is fixed
    Left is Count - 1,
    final_cache(Left, Lines).

% basetype/1 to be given in program file.

% final_regs(N, Regs)
%   Regs is a list of N values, each accepted by basetype/1.
final_regs(0, []).
final_regs(Count, [Val|Vals]) :-
    Count > 0,
    Left is Count - 1,
    basetype(Val),
    final_regs(Left, Vals).


%----------------------------

% ans(Threads, Mem)
%   (Threads, Mem) is a final state reachable from the initial state.
ans(Threads, Mem) :-
    init_state(Threads0, Mem0),
    final_state(Threads, Mem),
    reach(Threads0, Mem0, Threads, Mem).

% write_ans(Threads, Mem)
%   Prints the thread list, one thread per line, followed by the
%   memory state and a blank line.
write_ans(Threads, Mem) :-
    writeln('Threads= '),
    write_list(Threads),
    write('Mem = '),
    writeln(Mem),
    nl.

% write_list(List): prints each element of List on its own line.
write_list([]).
write_list([Item|Items]) :-
    writeln(Item),
    write_list(Items).

%----------------------------------------

% next_instr(PC, NextPC)
%   PC is (Thread, LocalAddr).  NextPC keeps the thread and moves to
%   LocalAddr+1 when instr/2 has an instruction there; otherwise its
%   local part is the atom end.
next_instr((Tid, Addr), (Tid, NextAddr)) :-
    Succ is Addr + 1,
    (   instr((Tid, Succ), _)
    ->  NextAddr = Succ
    ;   NextAddr = end
    ).

