"""
Durus SQLite storage engine.
Based on code from file_storage.py, Copyright (c) Corporation for National Research Initiatives 2006. 

Author: Peter Wilkinson, pw@thirdfloor.com.au, Thirdfloor Software Works Pty. Ltd.


Single pass packer, not incremental
"""

from durus.serialize import unpack_record, split_oids, extract_class_name
from durus.connection import ROOT_OID
from durus.utils import p64, u64
from durus.storage import Storage
try:
    import apsw
except ImportError:
    # apsw (Another Python SQLite Wrapper) is a hard requirement; point the
    # user at its homepage, then let the ImportError propagate.
    # Single-argument print() is a parenthesized print statement under
    # Python 2, so the output is identical on Python 2 and 3.
    print("************************************************")
    print("Required module, apsw, missing.")
    print("Available at http://www.rogerbinns.com/apsw.html")
    print("************************************************")
    raise

import time
# NOTE: the 'sets' module is deprecated (builtin set exists since 2.4);
# kept only because Set() is referenced by the packer.
from sets import Set

def mytrace(statement, bindings):
    """Execution tracer for apsw: print each SQL statement (and its
    bindings, if any) just before it runs.  Returning True lets execution
    proceed; returning False would abort the statement."""
    # %-formatted single-argument print() produces the same output as the
    # original Python 2 'print "SQL:", statement' and is also valid Python 3.
    print("SQL: %s" % (statement,))
    if bindings:
        print("Bindings: %s" % (bindings,))
    return True  # if you return False then execution is aborted

class SQLiteStorage (Storage):
    """
    Durus storage backend keeping pickled object records in an SQLite
    database, accessed through the apsw bindings.

    Tables:
      records       -- oid (INTEGER PRIMARY KEY), pack (mark bit used by the
                       packer), record (pickle BLOB), trans_id (row id of the
                       committing transaction in trans_history)
      trans_history -- one row per commit, with a 'modified' timestamp.
    """

    # Number of records to pack before yielding.  The current packer is a
    # single pass and never yields mid-way, so this is presently unused.
    _PACK_INCREMENT = 500

    def __init__(self, filename=None, auto_vacuum=1, default_cache_size=20000,
            fullfsync=0, synchronous=0, debug=False):
        """(filename:str=None, auto_vacuum:int=1, default_cache_size:int=20000,
            fullfsync:int=0, synchronous:int=0, debug:bool=False)
        If filename is empty (or None), an in-memory database is used.
        auto_vacuum, default_cache_size, fullfsync and synchronous are SQLite
        PRAGMA options (only applied when the schema is first created).
        debug=True traces every SQL statement to stdout via mytrace().
        """
        if not filename:
            filename = ':memory:'

        self.con = apsw.Connection(filename)
        self.con.setbusytimeout(5000)  # wait up to 5s when the db is locked
        self.cursor = self.con.cursor()
        if debug:
            self.cursor.setexectrace(mytrace)
        self.pack_extra = None  # oids stored while a pack is in progress
        self.table = 'records'

        # Schema detection matches the literal CREATE statement text stored
        # in sqlite_master, so these strings must stay byte-identical between
        # releases or existing databases will look uninitialized.
        create_record_sql = "CREATE TABLE %s (oid INTEGER PRIMARY KEY, pack INTEGER DEFAULT 0, " \
                "record BLOB, trans_id INTEGER)" % self.table
        # NOTE(review): no %-formatting is ever applied to this string, so the
        # doubled '%%' reaches SQLite verbatim and strftime() renders it as a
        # literal '%'.  Harmless in practice -- end() always supplies an
        # explicit 'modified' value -- and correcting it now would break the
        # sqlite_master match for databases created by older code.
        create_history_sql = "CREATE TABLE trans_history (trans_id INTEGER PRIMARY KEY, " \
                "modified TEXT DEFAULT (strftime('%%Y-%%m-%%d %%H:%%M:%%f','now')))"
        try:
            next(self.cursor.execute("SELECT * FROM sqlite_master WHERE sql = '%s'" % create_record_sql))
            next(self.cursor.execute("SELECT * FROM sqlite_master WHERE sql = '%s'" % create_history_sql.replace("'", "''")))
        except (StopIteration, apsw.SQLError):
            # Either table is missing: configure the PRAGMAs and create the
            # schema from scratch.
            print("Create tables.")

            self.cursor.execute("PRAGMA auto_vacuum = %d" % auto_vacuum)
            self.cursor.execute("PRAGMA default_cache_size = %d" % default_cache_size)
            self.cursor.execute("PRAGMA fullfsync = %d" % fullfsync)
            self.cursor.execute("PRAGMA synchronous = %d" % synchronous)

            self.cursor.execute(create_record_sql)
            self.cursor.execute(create_history_sql)

        self.oid = self._select_maxoid()
        self.transaction = {}  # oid -> record, buffered until end()

        # Only an *empty* table (None) means "no oids allocated yet".  The
        # previous truthiness test also reset self.oid when the highest oid
        # was 0 (a database holding just the root object), which made the
        # next new_oid() hand out oid 0 again, colliding with ROOT_OID.
        if self.oid is None:
            self.oid = -1

    def get_size(self):
        """() -> int
        Number of object records currently stored."""
        return next(self.cursor.execute("SELECT count(oid) from %s" % self.table))[0]

    def new_oid(self):
        """() -> oid:str
        Allocate the next unused oid and return it packed with p64()."""
        self.oid += 1
        return p64(self.oid)

    def load(self, oid):
        """(oid:str) -> record:str
        Return the raw record for oid.  Raises KeyError when the oid is
        absent and IOError when the storage has been closed."""
        if self.cursor is None:
            raise IOError('storage is closed')
        try:
            row = next(self.cursor.execute(
                "SELECT record from %s WHERE oid = ?" % self.table,
                (u64(oid),)))
            return str(row[0])
        except StopIteration:
            raise KeyError(u64(oid))

    def begin(self):
        """Start a transaction.  Nothing to do here: stores are buffered in
        self.transaction until end()."""
        pass

    def store(self, oid, record):
        """(oid:str, record:str)
        Buffer a record; it is written to SQLite by end()."""
        self.transaction[oid] = record

    def end(self, handle_invalidations=None):
        """Commit the buffered records in one SQLite transaction, stamped
        with a fresh trans_history row."""
        if self.cursor is None:
            raise IOError('storage is closed')

        self.cursor.execute("BEGIN")
        try:
            # Allocate the transaction id first so every record row written
            # below can reference it.
            self.cursor.execute(
                "INSERT INTO trans_history (modified) VALUES (?)",
                (time.time(),))
            trans_id = next(self.cursor.execute("SELECT last_insert_rowid()"))[0]
            for oid, record in self.transaction.items():
                # buffer() hands apsw the record as a BLOB (Python 2).
                self.cursor.execute(
                    "INSERT OR REPLACE INTO %s (oid, record, trans_id) "
                    "VALUES (?, ?, ?)" % self.table,
                    (u64(oid), buffer(record), trans_id))
        except:
            # Don't leave the connection inside an open BEGIN; roll back and
            # re-raise the original error.
            self.cursor.execute("ROLLBACK")
            raise
        self.cursor.execute("COMMIT")

        if self.pack_extra is not None:
            self.pack_extra.extend(self.transaction.keys())
        self.transaction = {}

    def sync(self):
        """() -> [oid:str]
        No other writer shares this connection, so there are never any
        invalidations to report."""
        return []

    def gen_oid_record(self):
        """() -> sequence of (oid, record:str)
        Iterate over every stored record.
        NOTE(review): this yields the integer oid straight from SQLite, not
        the p64()-packed form used by load()/store() -- confirm what callers
        expect before relying on it."""
        for row in self.cursor.execute("SELECT oid, record from %s" % self.table):
            yield row[0], str(row[1])

    def _select_maxoid(self):
        """() -> int | None
        Highest oid in use, or None when no records exist (SELECT MAX over
        an empty table yields a single NULL)."""
        return next(self.cursor.execute("SELECT MAX(oid) from %s" % self.table))[0]

    def close(self):
        """Release the cursor and connection.  Safe to call more than once."""
        self.cursor = None
        if self.con is not None:
            self.con.close()
            self.con = None  # make a second close() a no-op

    def _packer(self):
        """Generator performing a single-pass mark-and-sweep pack: mark every
        record reachable from the root, delete the rest, then clear the marks
        ready for the next pack.  Yields exactly once, when finished."""
        if self.cursor is None:
            raise IOError('storage is closed')

        assert not self.transaction

        self.cursor.execute("BEGIN")
        try:
            # Mark phase: walk the object graph starting at the root.
            todo = [ROOT_OID]
            seen = set()
            while todo:
                oid = todo.pop()
                if oid in seen:
                    continue
                seen.add(oid)
                record = self.load(oid)
                record_oid, data, refdata = unpack_record(record)
                assert oid == record_oid
                todo.extend(split_oids(refdata))
                self.cursor.execute(
                    "UPDATE %s SET pack = 1 WHERE oid = ?" % self.table,
                    (u64(oid),))

            # Sweep phase: drop everything left unmarked, then reset the
            # marks so the next pack starts clean.
            self.cursor.execute("DELETE FROM %s WHERE pack = 0" % self.table)
            self.cursor.execute("UPDATE %s SET pack = 0" % self.table)
        except:
            # Roll back rather than leave an open BEGIN, then re-raise.
            self.cursor.execute("ROLLBACK")
            raise
        self.cursor.execute("COMMIT")
        # Clear the pack-in-progress flag.  Previously it was never reset,
        # so a second get_packer() call always failed its
        # 'pack_extra is None' assertion.
        self.pack_extra = None
        yield None

    def get_packer(self):
        """Return an incremental packer (a generator).  Each time next() is
        called, up to _PACK_INCREMENT records will be packed.  Note that the
        generator must be exhausted before calling get_packer() again.
        """
        if self.cursor is None:
            raise IOError('storage is closed')

        assert not self.transaction
        assert self.pack_extra is None
        self.pack_extra = []
        return self._packer()

    def pack(self):
        """Run the packer to completion in one call."""
        for step in self.get_packer():
            pass
