# Interactive solver for 8-symbol equation puzzles (guesses like "8+7*6=50").
# Symbol alphabet used to build candidate equations.
DIGITS = set("0123456789")
ZERO = {"0"}
OPERATORS = set("+-*/")
EQUAL = {"="}

# Single-letter feedback codes the user types for each of the 8 positions.
CORRECT, INCORRECT, ABSENT = "c", "i", "a"

class Solver:
    """Interactive solver for an 8-symbol equation-guessing puzzle.

    Candidates are 8-character strings of the form ``<lhs>=<result>`` built
    from ``digits``, ``operators`` and ``=``. After each guess the user
    reports one feedback letter per position (CORRECT, INCORRECT or ABSENT).
    :meth:`record` folds that feedback into the constraints used by
    :meth:`solve`.
    """

    def __init__(self, digits=DIGITS, operators=OPERATORS):
        # Symbols that may still occur in the answer; shrunk by record().
        self.digits = digits.copy()
        self.operators = operators.copy()
        # Known symbol at each of the 8 positions (None when unknown).
        self.corrects = [None] * 8
        # Symbols known NOT to be at each position.
        self.incorrects = [set() for _ in range(8)]
        # Symbols that have not appeared in any guess yet.
        self.unused = digits | operators
        # Symbols confirmed (CORRECT or INCORRECT feedback) to occur in the answer.
        self.present = set()

    def all_absent(self, char, expr, result):
        """Return True if every occurrence of ``char`` in ``expr`` is marked ABSENT in ``result``."""
        return all(r == ABSENT for e, r in zip(expr, result) if e == char)

    def record(self, last_expr, last_result):
        """Fold one guess and its per-position feedback string into the constraints."""
        for i, (e, r) in enumerate(zip(last_expr, last_result)):
            self.unused.discard(e)
            if r == CORRECT:
                self.corrects[i] = e
                self.present.add(e)
            elif r == INCORRECT:
                # Present in the answer, but not at this position.
                self.incorrects[i].add(e)
                self.present.add(e)
            elif r == ABSENT:
                self.incorrects[i].add(e)
                # Only rule the symbol out entirely if no occurrence of it in
                # this guess was marked present (duplicates can be partially absent).
                if self.all_absent(e, last_expr, last_result):
                    self.digits.discard(e)
                    self.operators.discard(e)
            # Any other letter is ignored; self.input() only accepts c/i/a.

    def next_set(self, expr, element=None):
        """Return the symbols allowed to follow the partial expression ``expr``.

        ``element`` is unused and is kept only for backward compatibility.
        The returned set may be shared (e.g. ``self.digits`` or ``EQUAL``), so
        callers must not mutate it in place.
        """
        if not expr:
            # No leading zero and no leading operator.
            return self.digits - ZERO
        if len(expr) == 1:
            return self.digits | self.operators
        if EQUAL <= set(expr):
            # Only digits after '='. Not reached from _solve, which stops at '='.
            return self.digits
        if len(expr) == 6 and expr[-1] not in EQUAL:
            # Force '=' at index 6 so at least one result digit still fits.
            return EQUAL
        if set(expr) & OPERATORS and expr[-1] not in OPERATORS:
            # After an operator has been used: continue the number, add an
            # operator, or close the left-hand side with '='.
            return self.digits | self.operators | EQUAL
        if expr[-1] in OPERATORS:
            # No leading zero after an operator, no consecutive operators.
            return self.digits - ZERO
        if len(expr) > 2 and set(expr[-3:]) <= DIGITS:
            # A first number longer than three digits cannot fit in 8 symbols.
            return self.operators
        return self.digits | self.operators

    def validate(self, expr, res):
        """Return True if the complete candidate ``expr`` agrees with all recorded feedback.

        ``res`` is the trailing result part of ``expr`` (the digits after '=').
        The left-hand side was already filtered position by position in
        :meth:`_solve`, so only the result positions are checked individually.
        """
        symbols = set(expr)
        # Once every symbol has been tried, each one still allowed was marked
        # present somewhere, so the answer must contain all of them.
        if not self.unused and (self.digits | self.operators) - symbols:
            return False
        if not symbols & EQUAL:
            return False
        # A symbol confirmed present must occur in any valid answer.
        if not self.present <= symbols:
            return False
        if not set(res) <= self.digits:
            return False
        offset = len(expr) - len(res)
        for i, e in enumerate(res):
            pos = offset + i
            if self.corrects[pos] is not None and e != self.corrects[pos]:
                return False
            if e in self.incorrects[pos]:
                return False
        return True

    def score(self, expr):
        """Return a sort key for ``expr``; lower is better (solve() takes the minimum).

        Favours expressions that use many different symbols and many symbols
        not yet tried in earlier guesses.
        """
        symbols = set(expr)
        return -(len(self.unused & symbols) + len(symbols))

    def evaluate(self, expr, solutions):
        """Complete ``expr`` (a symbol list ending in '=') with its value and keep it if valid.

        A candidate is kept only if the left-hand side evaluates to a
        non-negative integer, the full equation is exactly 8 symbols long,
        and it passes :meth:`validate`. Kept candidates are added to
        ``solutions`` as strings.
        """
        lhs = "".join(expr[:-1])
        try:
            # lhs is built only from the solver's own digit/operator symbols,
            # never from user-typed text.
            value = eval(lhs)
        except (SyntaxError, ArithmeticError):
            # e.g. "8+05" (leading zero), a trailing operator, or division by zero.
            return
        if value < 0 or int(value) != value:
            return
        res = list(str(int(value)))
        if len(expr) + len(res) != 8:
            return
        candidate = expr + res  # new list; the caller's expr is left untouched
        if self.validate(candidate, res):
            solutions.add("".join(candidate))

    def _solve(self, solutions, expr=None):
        """Depth-first enumeration of candidate equations extending ``expr`` into ``solutions``."""
        if expr is None:
            expr = []
        pos = len(expr)
        if pos > 0 and expr[-1] in EQUAL:
            self.evaluate(expr, solutions)
            return
        allowed = self.next_set(expr) - self.incorrects[pos]
        if self.corrects[pos] is not None:
            allowed &= {self.corrects[pos]}
        for e in allowed:
            self._solve(solutions, expr + [e])

    def _solutions(self):
        """Return every consistent candidate as a ``(score, expr)`` pair."""
        solutions = set()
        self._solve(solutions)
        return {(self.score(e), e) for e in solutions}

    def solve(self):
        """Return ``((score, expr), count)`` for the best candidate and the number of candidates.

        Raises:
            IndexError: if no candidate is consistent with the recorded feedback.
        """
        scored = self._solutions()
        if not scored:
            raise IndexError("no candidate is consistent with the recorded feedback")
        return min(scored), len(scored)

    def input(self, prompt=""):
        """Prompt until the user enters 8 feedback letters (c/i/a), then return them."""
        valid = {CORRECT, INCORRECT, ABSENT}
        while True:
            # Inside the method body, `input` resolves to the builtin, not this method.
            s = input(prompt).strip().lower()
            if len(s) == 8 and set(s) <= valid:
                return s
            print("Invalid input")

    def run(self, expr="8+7*6=50", count=2562890625):
        """Interactive loop: show a guess, read feedback, and repeat with the best next guess.

        ``count`` is the number of candidates remaining. The default, 15**8,
        is every 8-symbol string over the 15 symbols; it marks the first call,
        which prints the legend. The loop stops when a single candidate
        remains (and prints it) or when no candidate is consistent.
        """
        if count == 2562890625:
            print("([c]orrect [i]ncorrect [a]bsent")
        while count > 1:
            self.record(expr, self.input(f"{expr} > "))
            try:
                (_score, expr), count = self.solve()
            except IndexError:
                print("Failed to find a suitable solution")
                return
        print(expr)


if __name__ == "__main__":
    # Start an interactive session from the solver's default opening guess.
    solver = Solver()
    solver.run()