(** Program optimizer. *)

(*
    il4c  --  Compiler for the IL4 Lisp-ahtava language
    Copyright (C) 2007 Jere Sanisalo

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

include Program

(** Goes through the function code and removes unneeded nops. *)
(** Goes through the function code and removes unneeded nops.

    Every [Block] in [func.func_code], at any nesting depth, has its direct
    [Nop] children dropped. Only direct children of a [Block] are filtered;
    a [Nop] in any other position (for example an [If] branch or a
    [Return] operand) is left in place.

    Returns a copy of [func] whose [func_code] is the cleaned tree. *)
let remove_unnecessary_nops func =
	let is_nop = function
		| Nop _ -> true
		| _ -> false
	in

	(* Rebuild the tree bottom-up, filtering nops out of each block's
	   (already cleaned) children. *)
	let rec strip = function
		| Block (children, r) ->
			let children = List.map strip children in
			Block (List.filter (fun s -> not (is_nop s)) children, r)
		| SetGlobal (n, s, r) -> SetGlobal (n, strip s, r)
		| SetLocal (n, s, r) -> SetLocal (n, strip s, r)
		| SetParam (n, s, r) -> SetParam (n, strip s, r)
		| Call (n, args, r) -> Call (n, List.map strip args, r)
		| CallAsm (n, args, r) -> CallAsm (n, List.map strip args, r)
		| CallC (ct, args, r) -> CallC (ct, List.map strip args, r)
		| If (cond, then_s, else_s, r) -> If (strip cond, strip then_s, strip else_s, r)
		| While (cond, body, r) -> While (strip cond, strip body, r)
		| Return (s, r) -> Return (strip s, r)
		| (Nop _ | ConstInt _ | ConstFloat _ | ConstRawArray _
		  | GetGlobal _ | GetLocal _ | GetParam _ | GetExtSymbol _
		  | Break _ | Continue _) as leaf -> leaf
	in

	{ func with func_code = strip func.func_code }

(** Removes all globals/functions from the program that are not accessed. *)
(** Removes all globals/functions from the program that are not accessed.

    Starting from the entry point function ["main"], every function reachable
    through [Call] statements is collected. Along the way, every global that is
    read or written ([GetGlobal]/[SetGlobal]) and every assembly function that
    is called ([CallAsm]) is collected as well. Each collected function is
    stored with its unnecessary nops removed (see [remove_unnecessary_nops]).

    The result is [orig_prg] with [prg_globals], [prg_asmfun_list] and
    [prg_fun_list] replaced by the reachable entries only, and with
    [prg_constants] emptied. New entries are prepended, so each list ends up
    in reverse discovery order.

    @raise Failure if [orig_prg] has no function named ["main"].
    @raise Not_found if reachable code refers to a function, assembly function
    or global that [orig_prg] does not define. *)
let remove_unneeded orig_prg =
	(* Adds the global [gname] from [orig_prg] to [prg], unless already collected. *)
	let try_collect_global prg gname =
		if List.mem_assoc gname prg.prg_globals then
			prg
		else
			let glb = List.assoc gname orig_prg.prg_globals in
			{ prg with prg_globals = (gname,glb) :: prg.prg_globals }
	in

	(* Adds the assembly function [fname] from [orig_prg] to [prg], unless
	   already collected. *)
	let try_collect_asmfun prg fname =
		if List.mem_assoc fname prg.prg_asmfun_list then
			prg
		else
			let asm_fun = List.assoc fname orig_prg.prg_asmfun_list in
			{ prg with prg_asmfun_list = (fname,asm_fun) :: prg.prg_asmfun_list }
	in

	(* Adds the function [fname] (with nops removed) to [prg], then collects
	   everything its body references. *)
	let rec collect_func prg fname =
		let func = List.assoc fname orig_prg.prg_fun_list in
		let func = remove_unnecessary_nops func in
		(* Register the function before walking its body, so that recursive
		   calls find it already collected instead of looping forever. *)
		let prg = { prg with prg_fun_list = (fname,func) :: prg.prg_fun_list } in
		collect_stmt prg func.func_code

	(* Collects the function [fname], unless already collected. *)
	and try_collect_fun prg fname =
		if List.mem_assoc fname prg.prg_fun_list then
			prg
		else
			collect_func prg fname

	(* Collects everything referenced by [stmt] and its sub-statements.
	   Leaf statements are listed explicitly (no catch-all), so the compiler's
	   exhaustiveness warning flags this traversal when a new statement
	   constructor is added. *)
	and collect_stmt prg stmt =
		match stmt with
		| Block (stmts,_) -> List.fold_left collect_stmt prg stmts
		| GetGlobal (v,_) -> try_collect_global prg v
		| SetGlobal (v,s,_) -> collect_stmt (try_collect_global prg v) s
		| SetLocal (_,s,_) -> collect_stmt prg s
		| SetParam (_,s,_) -> collect_stmt prg s
		| Call (v,stmts,_) -> List.fold_left collect_stmt (try_collect_fun prg v) stmts
		| CallAsm (v,stmts,_) -> List.fold_left collect_stmt (try_collect_asmfun prg v) stmts
		| CallC (_,stmts,_) -> List.fold_left collect_stmt prg stmts
		| If (s1,s2,s3,_) -> List.fold_left collect_stmt prg [s1; s2; s3]
		| While (s1,s2,_) -> List.fold_left collect_stmt prg [s1; s2]
		| Return (s,_) -> collect_stmt prg s
		| Nop _ | ConstInt _ | ConstFloat _ | ConstRawArray _
		| GetLocal _ | GetParam _ | GetExtSymbol _
		| Break _ | Continue _ -> prg
	in

	(* Start with a program that has nothing collected yet.
	   NOTE(review): [prg_constants] is cleared here and nothing below ever
	   refills it; presumably constants are no longer needed at this stage
	   (e.g. already inlined by an earlier pass) -- confirm. *)
	let new_prg =
		{ orig_prg with
			prg_constants = [];
			prg_globals = [];
			prg_asmfun_list = [];
			prg_fun_list = [] }
	in

	(* Everything reachable is found by starting from the entry point. *)
	if not (List.mem_assoc "main" orig_prg.prg_fun_list) then
		failwith "The entry point function 'main' not found in the program.";
	collect_func new_prg "main"
