from six import PY2

from iota.types import TryteString
from iota.crypto.kerl import conv
from iota.crypto.kerl.conv import TRIT_HASH_LENGTH
from iota.crypto.signing import normalize
from iota.crypto.kerl.pykerl import Kerl
from iota.crypto.types import PrivateKey
from iota.crypto.addresses import AddressGenerator

from sha3 import keccak_384

class CalcPrivate:
    """Try to rebuild a private key from one signature and its bundle hash.

    The search only applies when the first normalized bundle-hash value is
    13 (this is asserted).  In that case the first 81-tryte signature
    fragment is hashed ``13 - 13 = 0`` times and is therefore used unchanged
    as the first key fragment.  Its trits are turned back into a Keccak-384
    state (trying each possible value of the last trit), and the following
    key fragments are produced from that state with ``_one_round``.

    All work happens in the constructor.  Progress and failures are reported
    with ``print``; the outcome is read with ``get_private_key()``, which
    returns ``None`` when no key reproducing ``address`` was found.

    :param address: address the recovered key must reproduce; compared with
        ``==`` against the checksummed address derived from the candidate.
    :param signature: signature trytes; the length must be a multiple of 81
        and contain at least two 81-tryte fragments.
    :param bundle_hash: bundle hash passed to ``normalize``.
    :raises AssertionError: when a precondition above does not hold, or when
        a regenerated fragment contradicts the supplied signature.
    """

    # Opt-in diagnostic for seeds longer than 81 characters: when a candidate
    # last trit does not match, keep generating up to this many further
    # fragments and report whether a later one matches.  0 disables the
    # probe (the default).  A hit is reported and then aborts with
    # AssertionError; the probe does not recover a key.
    _LONG_SEED_PROBE_DEPTH = 0

    def __init__(self, address, signature, bundle_hash):
        self._address = address
        self._signature = signature
        self._bundle_hash = bundle_hash
        # Initialise result/state *before* the search so that every exit
        # path of _calc_private leaves the instance in a usable state.
        self._k = None
        self._priv_key = None
        self._calc_private()
        # The working hash state is not needed once the search is finished.
        self._k = None

    def _one_round(self):
        """Return the next fragment as trytes and advance the hash state.

        The current Keccak-384 digest is converted to trits (last trit
        forced to 0) for the return value; the bit-flipped digest bytes seed
        a fresh Keccak-384 instance that becomes the new state.
        """
        assert self._k is not None
        unsigned_hash = self._k.digest()

        if PY2:
            # On Python 2 the digest is a str; turn it into integers.
            unsigned_hash = map(ord, unsigned_hash)  # type: ignore

        signed_hash = [conv.convert_sign(b) for b in unsigned_hash]

        trits_from_hash = conv.convertToTrits(signed_hash)
        trits_from_hash[TRIT_HASH_LENGTH - 1] = 0

        flipped_bytes = bytearray(conv.convert_sign(~b) for b in unsigned_hash)

        self._k = keccak_384()
        self._k.update(flipped_bytes)
        return conv.trits_to_trytes(trits_from_hash)

    def _output_to_k_state(self, trits):
        """Rebuild the hash state that follows an output equal to ``trits``.

        Converts the trits to bytes, flips them the same way ``_one_round``
        does, and feeds them into a fresh Keccak-384 instance.
        """
        signed = conv.convertToBytes(trits)
        unsigned = [conv.convert_sign(b) for b in signed]
        flipped = bytearray(conv.convert_sign(~b) for b in unsigned)
        self._k = keccak_384()
        self._k.update(flipped)

    @staticmethod
    def _hash_rounds(sponge, trytes, rounds):
        """Hash ``trytes`` ``rounds`` times with ``sponge``; return trytes."""
        buffer = conv.trytes_to_trits(trytes)
        for _ in range(rounds):
            sponge.reset()
            sponge.absorb(buffer)
            sponge.squeeze(buffer)
        return conv.trits_to_trytes(buffer)

    def _probe_long_seed(self, sponge, rounds, expected):
        """Diagnostic: look for ``expected`` among later generated fragments.

        Only runs when ``_LONG_SEED_PROBE_DEPTH`` is non-zero.
        """
        for count in range(self._LONG_SEED_PROBE_DEPTH):
            candidate = self._one_round()
            if self._hash_rounds(sponge, candidate, rounds) == expected:
                print("Found at count: %d" % count)
                raise AssertionError(
                    "fragment matched at depth %d; long seed suspected" % count
                )
        print('no luck with depth')

    def _calc_private(self):
        """Run the search and store the key in ``self._priv_key`` on success.

        Leaves ``self._priv_key`` untouched (``None``) when the last trit
        cannot be determined or no candidate key reproduces the address.
        """
        normalized_chunks = []
        for chunk in normalize(self._bundle_hash):
            normalized_chunks.extend(chunk)
        assert normalized_chunks[0] == 13
        assert len(self._signature) % 81 == 0
        parts = [
            self._signature[ofs:ofs + 81]
            for ofs in range(0, len(self._signature), 81)
        ]
        first = parts[0]

        trits = conv.trytes_to_trits(first)

        assert trits[TRIT_HASH_LENGTH - 1] == 0
        assert len(parts) > 1

        sponge = Kerl()
        # Fragment 0 needs zero hashing rounds, so it is the key fragment.
        private_key_parts = [first]
        second_rounds = 13 - normalized_chunks[1]

        # find the last trit: _one_round always emits 0 there, so the value
        # that was actually in the state has to be guessed and verified
        # against the second signature fragment.
        for option in (-1, 0, 1):
            trits[TRIT_HASH_LENGTH - 1] = option
            self._output_to_k_state(trits)
            priv_part = self._one_round()
            if self._hash_rounds(sponge, priv_part, second_rounds) == parts[1]:
                print("Last Tryte found %d" % option)
                private_key_parts.append(priv_part)
                break
            # long seed case, one of the future ones matches
            if self._LONG_SEED_PROBE_DEPTH:
                self._probe_long_seed(sponge, second_rounds, parts[1])

        if len(private_key_parts) != 2:
            print("Missing trit not found.")
            print("Did you give a Curl instead of a Kerl based address?")
            print("Or did you provide the 2nd or 3rd part of the signature instead of the first?")
            print("If you suspect a longer than 81 character seed, see the code")
            return
        assert len(normalized_chunks) == 81

        for idx in range(2, 27 * 3):
            priv_part = self._one_round()

            # if we have enough signature pieces, verify
            if idx < len(parts):
                rounds = 13 - normalized_chunks[idx]
                if self._hash_rounds(sponge, priv_part, rounds) != parts[idx]:
                    print('mis_match ??')
                    # Explicit raise: a bare ``assert False`` vanishes
                    # under ``python -O`` and would accept a bad fragment.
                    raise AssertionError(
                        "regenerated fragment %d does not match the signature"
                        % idx
                    )
            private_key_parts.append(priv_part)

            # Candidate keys end on multiples of 27 fragments.
            if (idx % 27) != 26:
                continue
            # we have more signature parts, then this can't be the whole key yet
            if len(parts) > (idx + 1):
                continue

            level = (idx + 1) // 27
            priv_key = PrivateKey(
                TryteString("".join(map(str, private_key_parts)))
            )
            addr = AddressGenerator.address_from_digest(
                priv_key.get_digest()
            ).with_valid_checksum()
            if addr == self._address:
                self._priv_key = priv_key
                print("Private key found for security level %d" % level)
                return
            print("No match for security level %d" % level)

        print("Private key not found, was the address from Curl instead of Kerl?")

    def get_private_key(self):
        """Return the recovered ``PrivateKey``, or ``None`` if none was found."""
        return self._priv_key
