# Download this file
# Symbol alphabets for the 8-symbol equation puzzle.
DIGITS = set("0123456789")
ZERO = {"0"}
OPERATORS = set("+-*/")
EQUAL = {"="}

# Per-position feedback letters the user types after each guess.
CORRECT = "c"      # right symbol, right position
INCORRECT = "i"    # symbol is in the answer, but elsewhere
ABSENT = "a"       # symbol is not (or no further copy is) in the answer
class Solver:
    """Brute-force solver for an 8-symbol equation-guessing puzzle.

    Each guess is an equation such as "8+7*6=50". After a guess the user
    types one feedback letter per position (CORRECT, INCORRECT or ABSENT).
    record() turns that into constraints. solve() enumerates every
    well-formed equation still consistent with them and picks the
    best-scoring one.
    """

    # Default `count` for run(). It equals 15 ** 8 (15 symbols over 8
    # positions) and doubles as a "fresh session" sentinel that makes run()
    # print the feedback legend.
    _FRESH_SESSION = 15 ** 8

    def __init__(self, digits=DIGITS, operators=OPERATORS):
        # Copies, because record() removes absent symbols from these in place.
        self.digits = digits.copy()
        self.operators = operators.copy()
        # corrects[i]: the symbol known to be at position i, or None.
        self.corrects = [None] * 8
        # incorrects[i]: symbols known NOT to be at position i.
        self.incorrects = [set() for _ in range(8)]
        # Symbols that have not appeared in any guess yet.
        self.unused = digits | operators
        # Symbols reported CORRECT or INCORRECT: the answer must contain each.
        self.required = set()

    def all_absent(self, char, expr, result):
        """Return True if every occurrence of *char* in *expr* is marked ABSENT.

        A symbol that does not occur in *expr* trivially yields True.
        """
        return all(r == ABSENT for e, r in zip(expr, result) if e == char)

    def record(self, last_expr, last_result):
        """Fold the feedback for one guess into the solver's constraints.

        last_expr: the guessed equation, one symbol per position.
        last_result: one feedback letter (CORRECT/INCORRECT/ABSENT) per position.
        """
        for i, (e, r) in enumerate(zip(last_expr, last_result)):
            self.unused.discard(e)
            if r == CORRECT:
                self.corrects[i] = e
                self.required.add(e)
            elif r == INCORRECT:
                # Present in the answer, but not at this position.
                self.incorrects[i].add(e)
                self.required.add(e)
            elif r == ABSENT:
                self.incorrects[i].add(e)
                # Drop the symbol entirely only when no copy of it in this
                # guess was reported present. A repeated symbol can be ABSENT
                # at one position and CORRECT/INCORRECT at another.
                if self.all_absent(e, last_expr, last_result):
                    self.digits.discard(e)
                    self.operators.discard(e)
            # Any other letter is ignored; input() only accepts c/i/a.

    def next_set(self, expr, element=None):
        """Return the symbols allowed to follow the partial equation *expr*.

        Grammar encoded here:
        - the first symbol is a non-zero digit;
        - the number before the first operator has at most three digits;
        - an operator is always followed by a non-zero digit;
        - once an operator is present, '=' may follow any digit;
        - '=' is forced at position 6.

        element: unused; kept (now defaulting to None instead of a shared
        mutable list) for backward compatibility.
        """
        if len(expr) == 0:
            return self.digits - ZERO
        if len(expr) == 1:
            return self.digits | self.operators
        if set(expr) & EQUAL == EQUAL:
            return self.digits
        if len(expr) == 6 and expr[-1] not in EQUAL:
            return EQUAL
        if set(expr) & OPERATORS and expr[-1] not in OPERATORS:
            return self.digits | self.operators | EQUAL
        if expr[-1] in OPERATORS:
            return self.digits - ZERO
        if len(expr) > 2 and not set(expr[-3:]) - DIGITS:
            # Three leading digits already: the first operand is complete.
            return self.operators
        return self.digits | self.operators

    def validate(self, expr, res):
        """Check a complete equation against the accumulated constraints.

        expr: the full equation as a sequence of symbols (left side, '=', result).
        res: the result digits, i.e. the tail of *expr*.

        Positional constraints on the left-hand side are already enforced while
        it is generated in _solve(). This checks the result digits plus the
        whole-equation rules.
        """
        symbols = set(expr)
        # Every symbol reported present must appear somewhere.
        if not self.required <= symbols:
            return False
        # Once every symbol has been tried, the remaining candidates are
        # exactly the present ones, so all of them must be used.
        if not self.unused and (self.digits | self.operators) - symbols:
            return False
        if not symbols & EQUAL:
            return False
        if not set(res) <= self.digits:
            return False
        offset = len(expr) - len(res)
        for i, e in enumerate(res):
            pos = offset + i
            known = self.corrects[pos]
            if known is not None and e != known:
                return False
            if e in self.incorrects[pos]:
                return False
        return True

    def score(self, expr):
        """Rank a candidate equation: lower is better.

        Favours expressions that use many distinct symbols and many symbols
        not yet tried in any guess, so each guess reveals more information.
        """
        symbols = set(expr)
        # Intersection: count the yet-unused symbols this guess would try.
        return -(len(self.unused & symbols) + len(symbols))

    @staticmethod
    def _exact_value(symbols):
        """Evaluate digit/operator symbols with exact rational arithmetic.

        Uses the usual precedence: * and / bind tighter than + and -, and
        equal-precedence operators apply left to right. Returns a Fraction.
        Returns None when the sequence is not a well-formed expression (an
        empty operand from a leading, trailing or doubled operator, a
        non-digit operand, or a multi-digit number with a leading zero) or
        when it divides by zero.
        """
        from fractions import Fraction

        # Split e.g. "12+3*4" into operands ["12", "3", "4"] and ops ["+", "*"].
        operands, ops, current = [], [], ""
        for s in symbols:
            if s in ("+", "-", "*", "/"):
                operands.append(current)
                ops.append(s)
                current = ""
            else:
                current += s
        operands.append(current)
        for text in operands:
            if not text.isdecimal() or (len(text) > 1 and text[0] == "0"):
                return None

        # Fold * and / into the current additive term; + and - start a new one.
        terms = [Fraction(int(operands[0]))]
        signs = []
        for op, text in zip(ops, operands[1:]):
            n = int(text)
            if op == "*":
                terms[-1] *= n
            elif op == "/":
                if n == 0:
                    return None
                terms[-1] /= n
            else:
                signs.append(1 if op == "+" else -1)
                terms.append(Fraction(n))
        return terms[0] + sum(sign * term for sign, term in zip(signs, terms[1:]))

    def evaluate(self, expr, solutions):
        """Complete the left-hand side *expr* (ending in '=') and keep it if valid.

        The left side is evaluated exactly with fractions, so for example
        1/49*49 is exactly 1 rather than a float just below it. A non-negative
        integer value is appended as result digits when that makes the
        equation exactly 8 symbols long. The equation is added to *solutions*
        as a string if validate() accepts it. *expr* is not modified.
        """
        value = self._exact_value(expr[:-1])
        if value is None or value < 0 or value.denominator != 1:
            return
        res = list(str(value.numerator))
        full = list(expr) + res
        if len(full) == 8 and self.validate(full, res):
            solutions.add("".join(full))

    def _solve(self, solutions, expr=None):
        """Depth-first enumeration of equations consistent with the constraints.

        Extends the partial left-hand side *expr* one symbol at a time, using
        next_set() filtered by the per-position constraints. Once '=' is placed,
        evaluate() completes the equation and adds it to *solutions* if valid.
        """
        if expr is None:
            expr = []
        pos = len(expr)
        if pos > 0 and expr[-1] in EQUAL:
            self.evaluate(expr, solutions)
            return
        candidates = self.next_set(expr) - self.incorrects[pos]
        if self.corrects[pos] is not None:
            candidates &= {self.corrects[pos]}
        for symbol in candidates:
            self._solve(solutions, expr + [symbol])

    def _solutions(self):
        """Return {(score, equation)} for every equation still consistent."""
        solutions = set()
        self._solve(solutions)
        return {(self.score(e), e) for e in solutions}

    def solve(self):
        """Return ((score, equation), count) for the best remaining candidate.

        The best candidate has the lowest score; ties are broken by the
        equation string. count is the number of consistent candidates.
        Raises IndexError when no equation is consistent with the feedback.
        """
        candidates = self._solutions()
        if not candidates:
            raise IndexError("no equation is consistent with the feedback so far")
        return min(candidates), len(candidates)

    def input(self, prompt=""):
        """Prompt until the user enters 8 feedback letters (c/i/a); return them."""
        while True:
            # Name lookup skips the class scope, so this is the builtin input().
            s = input(prompt).strip()
            if len(s) == 8 and set(s) <= {CORRECT, INCORRECT, ABSENT}:
                return s
            print("Invalid input")

    def run(self, expr="8+7*6=50", count=_FRESH_SESSION):
        """Interactive loop: show a guess, read feedback, propose the next guess.

        expr: the guess to show first.
        count: how many candidates *expr* was chosen from. The default marks
            a fresh session and prints the feedback legend first. When count
            is 1 (or less), *expr* is printed as the final answer.
        """
        if count == self._FRESH_SESSION:
            print("[c]orrect [i]ncorrect [a]bsent")
        while count > 1:
            self.record(expr, self.input(f"{expr} > "))
            try:
                (_, expr), count = self.solve()
            except IndexError:
                print("Failed to find a suitable solution")
                return
        print(expr)
if __name__ == "__main__":
    # Start an interactive session from the default opening guess.
    solver = Solver()
    solver.run()